"vscode:/vscode.git/clone" did not exist on "6214dd6ce95cc2d00daa14b0db6ca0661cd83853"
Unverified Commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)



* Add general C API for setting tensor params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in inplace swizzle function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 quantize kernel with pre-swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Deprecate DSv3-specific quantization logic in C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove support for DSv3 compact data from quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update C++ swizzle test with swizzled scales API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for invalid NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors

Miscellaneous review suggestions from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make util header for function that computes swizzled scale index
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
......@@ -739,6 +739,7 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must have at least 2 dimensions.");
NVTE_CHECK(!output_.with_gemm_swizzled_scales, "Output must have scales in compact format.");
const SimpleTensor &input = input_.data;
SimpleTensor &global_amax = output_.amax;
SimpleTensor &output_t = output_.data;
......
......@@ -59,8 +59,8 @@ NVTEMatmulConfig nvte_create_matmul_config();
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[out] buf Memory address to write option value to.
* Ignored if NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
......@@ -71,9 +71,9 @@ void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
/*! \brief Set an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in/out] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to read option value.
* \param[in] buf Memory address to read option value from.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
......@@ -296,14 +296,15 @@ class MatmulConfigWrapper {
/*! \brief Set whether to compute GELU in GEMM epilogue. */
void set_with_gelu_epilogue(bool with_gelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue,
&with_gelu_epilogue, sizeof(bool));
const auto val = static_cast<uint8_t>(with_gelu_epilogue);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue, &val, sizeof(val));
}
/*! \brief Set whether to compute GELU backward in GEMM epilogue. */
void set_with_dgelu_epilogue(bool with_dgelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue,
&with_dgelu_epilogue, sizeof(bool));
const auto val = static_cast<uint8_t>(with_dgelu_epilogue);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue, &val,
sizeof(val));
}
/*! \brief Set auxiliary tensor for GEMM epilogue. */
......@@ -314,13 +315,15 @@ class MatmulConfigWrapper {
/*! \brief Set whether to use split accumulator for FP8 GEMM. */
void set_use_split_accumulator(bool use_split_accumulator) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator,
&use_split_accumulator, sizeof(bool));
const auto val = static_cast<uint8_t>(use_split_accumulator);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator, &val,
sizeof(val));
}
/*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
void set_sm_count(int sm_count) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int));
const auto val = static_cast<int32_t>(sm_count);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &val, sizeof(val));
}
private:
......
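Since sizeof(bool) is implementation-defined and this is a C ABI, the setters above now marshal boolean options through explicitly sized uint8_t values. A minimal sketch of the resulting call pattern, assuming a config created with nvte_create_matmul_config:

// Sketch: pass a boolean matmul-config option as a uint8_t byte.
NVTEMatmulConfig config = nvte_create_matmul_config();
const uint8_t use_split_accumulator = 1;  // 1 = true; sizeof(bool) is not portable across the C boundary
nvte_set_matmul_config_attribute(config, kNVTEMatmulConfigUseSplitAccumulator,
                                 &use_split_accumulator, sizeof(use_split_accumulator));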
......@@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/
/*! \file cast.h
* \brief Functions to cast to/from FP8.
/*! \file swizzle.h
* \brief Functions to convert scaling factors into format expected by GEMM.
*/
#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_
......@@ -47,7 +47,7 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen
/*! \brief Swizzle FP8 block-scaling scaling factors into the MXFP8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv.
* \param[in] input Input FP8 block-scaled tensor.
* \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
......@@ -57,7 +57,6 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen
* Requirements:
* - input is an FP8 block scaling tensor
* - input has rowwise usage
* - input.scale_inv is in GEMM_READY format
* - output is an MXFP8 tensor
* - output has rowwise usage
* - output.scale_inv has appropriate shape
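As a usage sketch of the general swizzle entry point declared in this header (tensor setup abbreviated; the data and scale_inv pointers are assumed to be populated elsewhere, and stream is an existing cudaStream_t), the input carries compact scales while the output is flagged as GEMM-swizzled:

// Sketch: convert compact MXFP8 scaling factors into the GEMM-swizzled layout.
transformer_engine::TensorWrapper input(NVTE_MXFP8_1D_SCALING);   // scales in compact format
transformer_engine::TensorWrapper output(NVTE_MXFP8_1D_SCALING);
// ... set rowwise data and scale_inv on both tensors ...
output.set_with_gemm_swizzled_scales(true);  // satisfies the requirement checks above
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);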
......
......@@ -13,6 +13,7 @@
#include <cuda_runtime_api.h>
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
......@@ -70,6 +71,7 @@ enum NVTETensorParam {
kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTEWithGEMMSwizzledScales = 7, /*!< Whether scaling factors are in format expected by GEMM */
kNVTENumTensorParams
};
......@@ -266,6 +268,8 @@ NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor);
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream);
/*! \brief Set a parameter of the tensor.
*
* \warning Deprecated in favor of nvte_set_tensor_param_v2.
*
* \param[in/out] tensor Tensor.
* \param[in] param_name The parameter to be set.
......@@ -275,12 +279,38 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const NVTEBasicTensor *param);
/*! \brief Get a value of the parameter of the tensor.
*
* \warning Deprecated in favor of nvte_get_tensor_param_v2.
*
* \param[in] tensor Tensor.
* \param[in] param_name The parameter to be queried.
*/
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name);
/*! \brief Set a tensor parameter.
*
* \param[in/out] tensor Tensor.
* \param[in] param Tensor parameter type.
* \param[in] buf Memory address to read parameter value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
size_t size_in_bytes);
/*! \brief Query a tensor parameter.
*
* \param[in] tensor Tensor.
* \param[in] param Tensor parameter type.
* \param[out] buf Memory address to write parameter value.
* Ignored if NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf,
size_t size_in_bytes, size_t *size_written);
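The v2 accessors follow the same query-then-read convention as the config-attribute functions: passing a NULL buffer reports the required size via size_written. A hedged sketch, assuming tensor is a valid NVTETensor:

// Sketch: round-trip the GEMM-swizzled-scales flag through the v2 accessors.
const uint8_t flag = 1;
nvte_set_tensor_param_v2(tensor, kNVTEWithGEMMSwizzledScales, &flag, sizeof(flag));
size_t needed = 0;
nvte_get_tensor_param_v2(tensor, kNVTEWithGEMMSwizzledScales, nullptr, 0, &needed);  // size query
uint8_t readback = 0;
nvte_get_tensor_param_v2(tensor, kNVTEWithGEMMSwizzledScales, &readback, sizeof(readback), nullptr);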
/*! \brief Get the granularity of scaling of this tensor.
*
* \param[in] tensor Tensor.
......@@ -326,12 +356,7 @@ enum NVTEQuantizationConfigAttribute {
conditional early even when captured in a static CUDA graph.
*/
kNVTEQuantizationConfigNoopTensor = 2,
/*! Data format for an FP8 block-scaled tensor
*
* This is not the right design since the tensor format is a
* property of the tensor, not the quantization. This enum will
* likely be refactored away in the future.
*/
/*! \warning Deprecated */
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
/*! RNG state (NVTETensor with 2 elements - seed and offset) */
kNVTEQuantizationConfigRNGState = 4,
......@@ -357,8 +382,8 @@ NVTEQuantizationConfig nvte_create_quantization_config();
*
* \param[in] config Quantization config.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[out] buf Memory address to write option value.
* Ignored if NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
......@@ -370,9 +395,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
/*! \brief Set an option in quantization config.
*
* \param[in] config Quantization config.
* \param[in/out] config Quantization config.
* \param[in] attr Option type.
* \param[out] buf Memory address to read option value.
* \param[in] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
......@@ -589,20 +614,20 @@ class TensorWrapper {
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) {
tensor_ = nvte_create_tensor(scaling_mode);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data);
nvte_set_tensor_param_v2(tensor_, kNVTERowwiseData, &data, sizeof(data));
NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32,
amax_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax);
nvte_set_tensor_param_v2(tensor_, kNVTEAmax, &amax, sizeof(amax));
NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32,
scale_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param(&tensor_, kNVTEScale, &scale);
nvte_set_tensor_param_v2(tensor_, kNVTEScale, &scale, sizeof(scale));
if (scale_inv_dptr == nullptr && scale_inv_shape.ndim == defaultShape.ndim &&
scale_inv_shape.ndim == 1 && scale_inv_shape.data[0] == defaultShape.data[0]) {
// Scale-inv pointer has not been provided and shape matches default
scale_inv_shape = emptyShape;
}
NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv);
nvte_set_tensor_param_v2(tensor_, kNVTERowwiseScaleInv, &scale_inv, sizeof(scale_inv));
}
/*! \brief Constructs new TensorWrapper.
......@@ -673,7 +698,7 @@ class TensorWrapper {
const ShapeType &shape) noexcept {
NVTEShape nvte_shape = this->convertShape(shape);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(type), nvte_shape};
nvte_set_tensor_param(&tensor_, param, &data);
nvte_set_tensor_param_v2(tensor_, param, &data, sizeof(data));
return *this;
}
......@@ -712,10 +737,17 @@ class TensorWrapper {
return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape);
}
void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) {
const auto val = static_cast<uint8_t>(with_gemm_swizzled_scales);
nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val));
}
// Parameter getters
NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept {
return nvte_get_tensor_param(tensor_, param);
NVTEBasicTensor ret;
nvte_get_tensor_param_v2(tensor_, param, &ret, sizeof(ret), nullptr);
return ret;
}
NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); }
......@@ -740,6 +772,12 @@ class TensorWrapper {
return get_parameter(kNVTEColumnwiseAmax);
}
bool get_with_gemm_swizzled_scales() const {
uint8_t val = 0;
nvte_get_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val), nullptr);
return static_cast<bool>(val);
}
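At the wrapper level the flag round-trips through the v2 API without any manual uint8_t handling; for instance (a sketch, with the tensor contents assumed to be set up elsewhere):

// Sketch: mark a wrapper's scales as pre-swizzled and read the flag back.
transformer_engine::TensorWrapper t(NVTE_MXFP8_1D_SCALING);
t.set_with_gemm_swizzled_scales(true);
const bool ready_for_gemm = t.get_with_gemm_swizzled_scales();  // true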
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
......@@ -919,15 +957,8 @@ class TensorWrapper {
NVTETensor tensor_ = nullptr;
};
/*! \enum Float8BlockScaleTensorFormat
* \brief Data format for an FP8 block-scaled tensor
*/
enum class Float8BlockScaleTensorFormat {
/*! FP8 data is transposed if needed and scales are swizzled */
GEMM_READY = 0,
/*! FP8 data is untransposed and scales are not swizzled or padded */
COMPACT = 1
};
/*! \warning Deprecated */
enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID };
/*! \struct QuantizationConfigWrapper
* \brief C++ wrapper for NVTEQuantizationConfig.
......@@ -968,8 +999,9 @@ class QuantizationConfigWrapper {
/*! \brief Set whether to force power of 2 scales */
void set_force_pow_2_scales(bool force_pow_2_scales) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales,
&force_pow_2_scales, sizeof(bool));
const auto val = static_cast<uint8_t>(force_pow_2_scales);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales, &val,
sizeof(val));
}
/*! \brief Set small value to add to amax */
......@@ -984,12 +1016,8 @@ class QuantizationConfigWrapper {
sizeof(NVTETensor));
}
/*! \brief Set FP8 block-scaled tensor format */
void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {
nvte_set_quantization_config_attribute(config_,
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat,
&format, sizeof(Float8BlockScaleTensorFormat));
}
/*! \warning Deprecated */
void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {}
/*! \brief Set stochastic rounding state */
void set_rng_state(NVTETensor rng_state) {
......@@ -999,20 +1027,23 @@ class QuantizationConfigWrapper {
/*! \brief Set whether to use 2D block scaling for NVFP4 */
void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) {
const auto val = static_cast<uint8_t>(nvfp4_2d_quantization);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization,
&nvfp4_2d_quantization, sizeof(bool));
&val, sizeof(val));
}
/*! \brief Set whether to use stochastic rounding */
void set_stochastic_rounding(bool stochastic_rounding) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding,
&stochastic_rounding, sizeof(bool));
const auto val = static_cast<uint8_t>(stochastic_rounding);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding, &val,
sizeof(val));
}
/*! \brief Set whether to enable fast math operations */
void set_use_fast_math(bool use_fast_math) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath,
&use_fast_math, sizeof(bool));
const auto val = static_cast<uint8_t>(use_fast_math);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath, &val,
sizeof(val));
}
private:
......
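The same uint8_t convention applies to the quantization-config wrapper; a typical usage sketch, assuming the wrapper is default-constructible:

// Sketch: toggle boolean quantization options through the C++ wrapper.
transformer_engine::QuantizationConfigWrapper quant_config;
quant_config.set_stochastic_rounding(true);  // forwarded to the C API as a uint8_t
quant_config.set_use_fast_math(false);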
......@@ -27,10 +27,15 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
// Check for unsupported configurations
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
if (is_mxfp8_scaling(z->scaling_mode)) {
NVTE_CHECK(!z->with_gemm_swizzled_scales,
"MXFP8 output must have scales in compact format, not swizzled for GEMM.");
}
NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
......
......@@ -23,10 +23,15 @@ using namespace normalization;
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
// Check for unsupported configurations
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
if (is_mxfp8_scaling(z->scaling_mode)) {
NVTE_CHECK(!z->with_gemm_swizzled_scales,
"MXFP8 output must have scales in compact format, not swizzled for GEMM.");
}
NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
......
......@@ -110,7 +110,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
OType *output_rowwise_minus_offset = output_rowwise - start_offset;
OType *output_colwise_minus_offset = output_colwise - start_offset;
int warp_idx = threadIdx.x / 32;
int lane_idx = threadIdx.x % 32;
// int lane_idx = threadIdx.x % 32;
int c = blockIdx.x * kColsPerTile + threadIdx.x;
int r = blockIdx.y * kRowsPerTile;
......
......@@ -340,6 +340,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
// Check tensors
CheckInputTensor(*input, "scaling_factor_input");
CheckInputTensor(*output, "scaling_factor_output");
NVTE_CHECK(!input->with_gemm_swizzled_scales,
"Expected input tensor with scales in compact format.");
NVTE_CHECK(output->with_gemm_swizzled_scales,
"Expected output tensor with scales in GEMM swizzled format.");
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
......@@ -656,6 +660,11 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
NVTE_CHECK(
(is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
"Not implemented scaling mode " + to_string(scaling_mode) + ".");
NVTE_CHECK(!input[i]->with_gemm_swizzled_scales,
"Expected input tensors with scales in compact format.");
NVTE_CHECK(output[i]->with_gemm_swizzled_scales,
"Expected output tensors with scales in GEMM swizzled format.");
// We don't allow empty tensors. They should be filtered out before calling this function.
NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty.");
CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
......
......@@ -98,7 +98,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
const void* const warp_src =
(reinterpret_cast<const uint8_t*>(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride);
// load scaling factors for this lane's initial four 1x128 tiles
uint4 sf;
......@@ -129,7 +130,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
// store them cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
void* const warp_dst =
(reinterpret_cast<uint8_t*>(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride);
reinterpret_cast<uint4*>(warp_dst)[lane] = sf;
}
......@@ -193,7 +195,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = sizeof(float);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
const void* const warp_src =
(reinterpret_cast<const uint8_t*>(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride);
// load scaling factor for this warp's 128x128 tile
uint32_t sf = *reinterpret_cast<const uint32_t*>(warp_src);
......@@ -208,7 +211,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
// store it cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
void* const warp_dst =
(reinterpret_cast<uint8_t*>(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride);
reinterpret_cast<uint4*>(warp_dst)[lane] = sf4;
}
......@@ -261,6 +265,9 @@ void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor*
NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0,
"Output must have E8M0 scaling factors");
NVTE_CHECK(output->with_gemm_swizzled_scales,
"Expected output tensor with scales in GEMM swizzled format.");
NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data");
NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input");
NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors");
......
......@@ -6,12 +6,16 @@
#include <transformer_engine/transformer_engine.h>
#include <algorithm>
#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "common.h"
#include "common/util/cuda_runtime.h"
......@@ -778,7 +782,8 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
t->columnwise_amax = *param;
break;
default:
NVTE_ERROR("Unknown tensor parameter!");
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
"). Consider using nvte_set_tensor_param_v2 instead.");
}
}
......@@ -803,7 +808,148 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
case kNVTEColumnwiseAmax:
return t.columnwise_amax;
default:
NVTE_ERROR("Unknown tensor parameter!");
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
"). Consider using nvte_set_tensor_param_v2 instead.");
}
}
void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast<int>(param),
")");
NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
const auto &attr_size = transformer_engine::Tensor::attr_sizes[param];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for tensor parameter "
"(parameter ",
static_cast<int>(param), " needs ", attr_size, " bytes, but buffer has ",
size_in_bytes, " bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
switch (param) {
case kNVTERowwiseData: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.data = *basic_tensor;
break;
}
case kNVTEColumnwiseData: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.columnwise_data = *basic_tensor;
break;
}
case kNVTEScale: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.scale = *basic_tensor;
break;
}
case kNVTEAmax: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.amax = *basic_tensor;
break;
}
case kNVTERowwiseScaleInv: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.scale_inv = *basic_tensor;
break;
}
case kNVTEColumnwiseScaleInv: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.columnwise_scale_inv = *basic_tensor;
break;
}
case kNVTEColumnwiseAmax: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.columnwise_amax = *basic_tensor;
break;
}
case kNVTEWithGEMMSwizzledScales:
t.with_gemm_swizzled_scales = static_cast<bool>(*reinterpret_cast<const uint8_t *>(buf));
break;
default:
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
}
}
void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf,
size_t size_in_bytes, size_t *size_written) {
using namespace transformer_engine;
// Check param
NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast<int>(param),
")");
// Write attribute size if provided
const auto &attr_size = Tensor::attr_sizes[param];
if (size_written != nullptr) {
*size_written = attr_size;
}
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for tensor parameter "
"(parameter ",
static_cast<int>(param), " needs ", attr_size, " bytes, but buffer has ",
size_in_bytes, " bytes)");
// Get C++ tensor
const Tensor *t = convertNVTETensor(tensor);
std::optional<Tensor> dummy;
if (t == nullptr) {
// Make dummy tensor if provided tensor is invalid
dummy.emplace();
t = &(*dummy);
}
// Write to buffer
switch (param) {
case kNVTERowwiseData: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->data);
break;
}
case kNVTEColumnwiseData: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_data);
break;
}
case kNVTEScale: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->scale);
break;
}
case kNVTEAmax: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->amax);
break;
}
case kNVTERowwiseScaleInv: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->scale_inv);
break;
}
case kNVTEColumnwiseScaleInv: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_scale_inv);
break;
}
case kNVTEColumnwiseAmax: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_amax);
break;
}
case kNVTEWithGEMMSwizzledScales:
*reinterpret_cast<uint8_t *>(buf) = static_cast<uint8_t>(t->with_gemm_swizzled_scales);
break;
default:
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
}
}
......@@ -854,10 +1000,12 @@ NVTEQuantizationConfig nvte_create_quantization_config() {
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written) {
using namespace transformer_engine;
// Write attribute size
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
const auto &attr_size = QuantizationConfig::attr_sizes[attr];
if (size_written != nullptr) {
*size_written = attr_size;
}
......@@ -874,12 +1022,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// bool size is implementation-dependent, so we explicitly specify
// uint8_t in the user-facing API.
auto bool_to_uint8 = [](bool in, void *out) {
*reinterpret_cast<uint8_t *>(out) = static_cast<uint8_t>(in);
};
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
const auto &config_ = *reinterpret_cast<const QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
bool_to_uint8(config_.force_pow_2_scales, buf);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(buf, &config_.amax_epsilon, attr_size);
......@@ -887,20 +1041,23 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(buf, &config_.noop_tensor, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: {
// Deprecated
const auto invalid = Float8BlockScaleTensorFormat::INVALID;
std::memcpy(buf, &invalid, attr_size);
break;
}
case kNVTEQuantizationConfigRNGState:
std::memcpy(buf, &config_.rng_state, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(buf, &config_.nvfp4_2d_quantization, attr_size);
bool_to_uint8(config_.nvfp4_2d_quantization, buf);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(buf, &config_.stochastic_rounding, attr_size);
bool_to_uint8(config_.stochastic_rounding, buf);
break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(buf, &config_.use_fast_math, attr_size);
bool_to_uint8(config_.use_fast_math, buf);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
......@@ -910,10 +1067,12 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, const void *buf,
size_t size_in_bytes) {
using namespace transformer_engine;
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
const auto &attr_size = QuantizationConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for quantization config attribute "
"(attribute ",
......@@ -921,12 +1080,18 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// bool size is implementation-dependent, so we explicitly specify
// uint8_t in the user-facing API.
auto uint8_to_bool = [](const void *in, bool &out) {
out = static_cast<bool>(*reinterpret_cast<const uint8_t *>(in));
};
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
auto &config_ = *reinterpret_cast<QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(&config_.force_pow_2_scales, buf, attr_size);
uint8_to_bool(buf, config_.force_pow_2_scales);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(&config_.amax_epsilon, buf, attr_size);
......@@ -935,19 +1100,19 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
std::memcpy(&config_.noop_tensor, buf, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
// Deprecated
break;
case kNVTEQuantizationConfigRNGState:
std::memcpy(&config_.rng_state, buf, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
uint8_to_bool(buf, config_.nvfp4_2d_quantization);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size);
uint8_to_bool(buf, config_.stochastic_rounding);
break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(&config_.use_fast_math, buf, attr_size);
uint8_to_bool(buf, config_.use_fast_math);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
......
......@@ -36,7 +36,7 @@ enum class FP8BlockwiseRowwiseOption {
NONE,
// Rowwise data, scales in GEMM format
ROWWISE_GEMM_READY,
// Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
// Deprecated
ROWWISE_COMPACT
};
......@@ -50,8 +50,7 @@ enum class FP8BlockwiseColumnwiseOption {
// On Hopper sm90, GEMM_READY means that columnwise quantization also fuses transpose op
// On higher sm versions with TN,NT,NN fp8 gemm, GEMM_READY doesn't fuse transpose
COLUMNWISE_GEMM_READY,
// Columnwise data in original shape
// Scales in compact format, needs extra processing (padding, transposing) before GEMM
// Deprecated
COLUMNWISE_COMPACT
};
......
......@@ -492,7 +492,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
}
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
const size_t row_length = input.shape.size() > 0 ? input.shape.back() : 1;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
......@@ -511,12 +511,14 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
if (return_transpose) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
NVTE_CHECK(output_t.shape.size() == input.shape.size(), "input (shape=", input.shape,
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
if (output_t.shape.size() > 0) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
NVTE_CHECK(output_t.shape.front() == input.shape.back(), "input (shape=", input.shape,
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
NVTE_CHECK(output_t.shape[i] == input.shape[i - 1], "input (shape=", input.shape,
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
......
......@@ -14,8 +14,10 @@
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
#include "./transpose.h"
namespace transformer_engine {
namespace detail {
namespace {
......@@ -203,7 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match.");
NVTE_CHECK(input.data.dtype == output.data.dtype, "Input (dtype=", to_string(input.data.dtype),
") and output (dtype=", to_string(output.data.dtype), ") do not match.");
if (noop.data.dptr != nullptr) {
NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), ".");
......@@ -283,19 +286,20 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
}); // NOLINT(*)
}
} // namespace detail
} // namespace transformer_engine
void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
detail::transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
}
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine;
transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
detail::transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
convertNVTETensor(output), stream);
}
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
#include "../common.h"
namespace transformer_engine {
namespace detail {
void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream);
} // namespace detail
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
......@@ -840,6 +840,7 @@ __device__ __forceinline__ int32_t elect_one_sync(uint32_t mask = 0xFFFFFFFFu) {
return pred;
#else
NVTE_DEVICE_ERROR("elect_one_sync is only supported on SM 10.0+.");
return 0;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
......@@ -891,6 +892,7 @@ __device__ __forceinline__ bf16 get_amax(bf16 a, bf16 b) {
return r;
#else
NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+.");
return 0.f;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
......@@ -903,6 +905,7 @@ __device__ __forceinline__ fp16 get_amax(fp16 a, fp16 b) {
return r;
#else
NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+.");
return 0.f;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
......
......@@ -83,7 +83,8 @@
pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>( \
m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \
.value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT) \
.value("INVALID", transformer_engine::Float8BlockScaleTensorFormat::INVALID); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()) \
.value("RS", transformer_engine::CommOverlapType::RS) \
......
......@@ -62,12 +62,17 @@ class DebugQuantizer(Quantizer):
self.tp_group = tp_group # used in inspect_tensor calls
self.iteration = TEDebugState.get_iteration()
# Configure parent quantizer
if parent_quantizer is not None:
# .internal = True is slightly faster, but results
# in errors when caching the weights.
# Setting .internal = False is safer.
if parent_quantizer is not None:
parent_quantizer.internal = False
# .optimize_for_gemm = True is not supported because debug
# quantizers perform non-GEMM operations.
parent_quantizer.optimize_for_gemm = False
self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]
# next iteration when this quantizer will call any API
......
......@@ -65,23 +65,33 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
NVTE_CHECK(typeToSize(scale_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
}
if (!is_nvfp4) {
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Assume MXFP8 scales are already swizzled
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
} else { // Swizzle for NVFP4
input.set_with_gemm_swizzled_scales(true);
} else if (is_nvfp4) { // Swizzle for NVFP4
NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS");
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
output.set_with_gemm_swizzled_scales(true);
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
input.set_with_gemm_swizzled_scales(true);
} else { // Tensor scaling
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
}
}
......@@ -669,6 +679,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
}
lhs_i.set_with_gemm_swizzled_scales(true);
if (rhs_use_colwise) {
rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
......@@ -678,6 +689,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
}
rhs_i.set_with_gemm_swizzled_scales(true);
if (!is_empty_gemm) {
lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i));
......
......@@ -164,17 +164,9 @@ def general_gemm(
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage):
# There is not use_split_accumulator == False
# implementation for Float8BlockwiseQTensorStorage GEMM
# FP8 block-scaling requires split accumulator
use_split_accumulator = True
# Check that data format is supported
if (
A._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
args = (
A,
transa, # transa
......
......@@ -301,11 +301,13 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
size_t roundup(const size_t value, const size_t multiple) {
size_t roundup(size_t value, size_t multiple) {
assert(multiple > 0);
return ((value + multiple - 1) / multiple) * multiple;
}
size_t ceildiv(size_t numer, size_t denom) { return (numer + denom - 1) / denom; }
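For reference, the two helpers satisfy roundup(value, multiple) == ceildiv(value, multiple) * multiple; e.g.:

// Illustrative values: roundup(10, 4) == 12, ceildiv(10, 4) == 3, roundup(12, 4) == 12
assert(roundup(10, 4) == ceildiv(10, 4) * 4);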
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
......