Commit 53fa872c authored by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

parents 27ddce40 40c69e75
@@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method,
   // Cache computations
   for (size_t i = i_min; i < i_max; ++i) {
     for (size_t j = j_min; j < j_max; ++j) {
       const size_t idx = i * cols + j;
       const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
@@ -310,12 +311,13 @@ void performTest_x1(const ProcessingMethod processing_method,
   const double rel_tolerable_mismatches_limit = 0.0;
   size_t mismatches_scales = 0;
-  compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
-                               unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
-                               mismatches_scales,
-                               scale_diff_abs_tolerance,
-                               abs_tolerable_mismatches_limit,
-                               rel_tolerable_mismatches_limit);
+  compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
+                          unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
+                          mismatches_scales,
+                          scale_diff_abs_tolerance,
+                          abs_tolerable_mismatches_limit,
+                          rel_tolerable_mismatches_limit);
   const size_t mismatches_elts = 32 * mismatches_scales;
   auto [atol, rtol] = getTolerances(otype);
@@ -481,22 +483,22 @@ void performTest_x2(const ProcessingMethod processing_method,
   const double rel_tolerable_mismatches_limit = 0.0;
   size_t mismatches_scales_rowwise = 0;
-  compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
-                               ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
-                               unpadded_blocks_X_rowwise, scales_stride_rowwise,
-                               mismatches_scales_rowwise,
-                               scale_diff_abs_tolerance,
-                               abs_tolerable_mismatches_limit,
-                               rel_tolerable_mismatches_limit);
+  compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
+                          ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+                          unpadded_blocks_X_rowwise, scales_stride_rowwise,
+                          mismatches_scales_rowwise,
+                          scale_diff_abs_tolerance,
+                          abs_tolerable_mismatches_limit,
+                          rel_tolerable_mismatches_limit);
   size_t mismatches_scales_colwise = 0;
-  compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
-                               ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
-                               unpadded_blocks_X_colwise, scales_stride_colwise,
-                               mismatches_scales_colwise,
-                               scale_diff_abs_tolerance,
-                               abs_tolerable_mismatches_limit,
-                               rel_tolerable_mismatches_limit);
+  compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
+                          ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+                          unpadded_blocks_X_colwise, scales_stride_colwise,
+                          mismatches_scales_colwise,
+                          scale_diff_abs_tolerance,
+                          abs_tolerable_mismatches_limit,
+                          rel_tolerable_mismatches_limit);
   const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
   const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
...
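A note on the mismatch accounting used in these tests: each MXFP8 scaling factor is shared by a block of 32 elements, so every tolerated scale mismatch is allowed to explain up to 32 element-level mismatches. A minimal sketch of that bookkeeping (illustrative Python, not the test code itself):

# Sketch: how many element mismatches a given number of wrong scales can account for.
MXFP8_BLOCK_SIZE = 32  # elements quantized per shared e8m0 scaling factor

def tolerable_element_mismatches(mismatched_scales: int) -> int:
    """Upper bound on element mismatches explained by mismatched scaling factors."""
    return MXFP8_BLOCK_SIZE * mismatched_scales

assert tolerable_element_mismatches(0) == 0
assert tolerable_element_mismatches(3) == 96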
@@ -267,19 +267,20 @@ void performTest_x1(const size_t rows,
                        ? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
                        : output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
   if (rowwise) {
-    compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
-                                 unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
-                                 mismatches_scales,
-                                 scale_diff_abs_tolerance,
-                                 abs_tolerable_mismatches_limit,
-                                 rel_tolerable_mismatches_limit);
+    compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
+                            unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
+                            mismatches_scales,
+                            scale_diff_abs_tolerance,
+                            abs_tolerable_mismatches_limit,
+                            rel_tolerable_mismatches_limit);
   } else {
-    compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
-                                 unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
-                                 mismatches_scales,
-                                 scale_diff_abs_tolerance,
-                                 abs_tolerable_mismatches_limit,
-                                 rel_tolerable_mismatches_limit);
+    compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
+                            unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
+                            mismatches_scales,
+                            scale_diff_abs_tolerance,
+                            abs_tolerable_mismatches_limit,
+                            rel_tolerable_mismatches_limit);
   }
   const size_t mismatches_elts = 32 * mismatches_scales;
@@ -378,21 +379,22 @@ void performTest_x2(const size_t rows,
   const double rel_tolerable_mismatches_limit = 1.0e-4;
   size_t mismatches_scales_rowwise = 0;
-  compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
-                               ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
-                               unpadded_blocks_X_rowwise, scales_stride_rowwise,
-                               mismatches_scales_rowwise,
-                               scale_diff_abs_tolerance,
-                               abs_tolerable_mismatches_limit,
-                               rel_tolerable_mismatches_limit);
+  compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
+                          ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+                          unpadded_blocks_X_rowwise, scales_stride_rowwise,
+                          mismatches_scales_rowwise,
+                          scale_diff_abs_tolerance,
+                          abs_tolerable_mismatches_limit,
+                          rel_tolerable_mismatches_limit);
   size_t mismatches_scales_colwise = 0;
-  compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
-                               ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
-                               unpadded_blocks_X_colwise, scales_stride_colwise,
-                               mismatches_scales_colwise,
-                               scale_diff_abs_tolerance,
-                               abs_tolerable_mismatches_limit,
-                               rel_tolerable_mismatches_limit);
+  compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
+                          ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+                          unpadded_blocks_X_colwise, scales_stride_colwise,
+                          mismatches_scales_colwise,
+                          scale_diff_abs_tolerance,
+                          abs_tolerable_mismatches_limit,
+                          rel_tolerable_mismatches_limit);
   const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
   const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
...
@@ -111,6 +111,10 @@ size_t DIVUP(const size_t &x, const size_t &y){
   return (((x) + ((y)-1)) / (y));
 }

+size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
+  return DIVUP(x, y) * y;
+}
+
 struct scale_inv_meta {
   std::vector<size_t> shape;
   DType type;
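The new DIVUP_TO_MULTIPLE helper composes the existing ceiling division with a multiplication, i.e. it rounds a size up to the next multiple of an alignment. A minimal sketch of the same arithmetic (illustrative Python, not the repository code):

# Sketch of what DIVUP and DIVUP_TO_MULTIPLE compute.
def divup(x: int, y: int) -> int:
    """Ceiling division: smallest n such that n * y >= x."""
    return (x + y - 1) // y

def divup_to_multiple(x: int, y: int) -> int:
    """Round x up to the next multiple of y."""
    return divup(x, y) * y

assert divup(130, 32) == 5                  # 130 elements need 5 blocks of 32
assert divup_to_multiple(130, 128) == 256   # padded up to a multiple of 128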
@@ -147,21 +151,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   scale_inv_meta ret_rowwise, ret_colwise;
-  auto block_alignment = std::vector<size_t>{128ul, 4ul};
-  {
-    auto alignment = block_alignment[0];
-    auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
-    alignment = block_alignment[1];
-    auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
-    ret_rowwise.shape = {scale_dim_0, scale_dim_1};
-  }
-  {
-    auto alignment = block_alignment[1];
-    auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
-    alignment = block_alignment[0];
-    auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
-    ret_colwise.shape = {scale_dim_0, scale_dim_1};
-  }
+  const size_t block_size_X_rowwise = 32;
+  size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
+  size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
+  ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
+  const size_t block_size_Y_colwise = 32;
+  size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
+  size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
+  ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
+  ret_rowwise.type = DType::kFloat8E8M0;
+  ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
+  ret_colwise.type = DType::kFloat8E8M0;
+  ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
+  return {ret_rowwise, ret_colwise};
+}
+if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
+  std::vector<size_t> shape_vec;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_vec.push_back(shape.data[i]);
+  }
+  size_t first_dim = first_dimension(shape_vec);
+  size_t last_dim = last_dimension(shape_vec);
+  NVTE_CHECK(last_dim % 32 == 0);
+  NVTE_CHECK(first_dim % 32 == 0);
+  scale_inv_meta ret_rowwise, ret_colwise;
+  size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
+  size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
+  ret_rowwise.shape = {scale_dim_Y, scale_dim_X};
+  size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
+  size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
+  ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
+  ret_rowwise.type = DType::kFloat8E4M3;
+  ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
+  ret_colwise.type = DType::kFloat8E4M3;
+  ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
+  return {ret_rowwise, ret_colwise};
+}
+if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
+  std::vector<size_t> shape_vec;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_vec.push_back(shape.data[i]);
+  }
+  size_t first_dim = first_dimension(shape_vec);
+  size_t last_dim = last_dimension(shape_vec);
+  scale_inv_meta ret_rowwise, ret_colwise;
+  const size_t block_size_X_rowwise = 32;
+  size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
+  size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
+  ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
+  const size_t block_size_Y_colwise = 32;
+  size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
+  size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
+  ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
   ret_rowwise.type = DType::kFloat8E8M0;
   ret_colwise.type = DType::kFloat8E8M0;
   ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
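To make the MXFP8 scale-shape arithmetic above concrete, here is a rough worked example (a Python sketch with hypothetical tensor sizes, not the C++ test helper): rowwise scales carry one e8m0 value per 32 consecutive columns and are padded to the [128, 4] alignment; colwise scales are the transposed layout padded to [4, 128].

# Sketch of the padded MXFP8 scale-inverse tensor shapes for a rows x cols tensor.
def divup(x, y):
    return (x + y - 1) // y

def divup_to_multiple(x, y):
    return divup(x, y) * y

def mxfp8_scale_shapes(rows, cols, block=32):
    # Rowwise: one e8m0 scale per `block` columns, padded to the [128, 4] alignment.
    rowwise = (divup_to_multiple(rows, 128), divup_to_multiple(divup(cols, block), 4))
    # Colwise: one e8m0 scale per `block` rows, padded to the [4, 128] alignment.
    colwise = (divup_to_multiple(divup(rows, block), 4), divup_to_multiple(cols, 128))
    return rowwise, colwise

# Hypothetical 256 x 1024 input tensor:
assert mxfp8_scale_shapes(256, 1024) == ((256, 32), (8, 1024))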
@@ -180,13 +234,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   scale_inv_meta ret_rowwise, ret_colwise;
   {
-    auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
-    auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
+    size_t scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
+    size_t scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
     ret_rowwise.shape = {scale_dim_0, scale_dim_1};
   }
   {
-    auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
-    auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
+    size_t scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
+    size_t scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
     ret_colwise.shape = {scale_dim_0, scale_dim_1};
   }
   ret_rowwise.type = DType::kFloat32;
@@ -206,13 +260,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   scale_inv_meta ret_rowwise, ret_colwise;
   {
-    auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
-    auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
+    size_t scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
+    size_t scale_dim_1 = DIVUP(first_dim, 4) * 4;
     ret_rowwise.shape = {scale_dim_0, scale_dim_1};
   }
   {
-    auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
-    auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
+    size_t scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
+    size_t scale_dim_1 = DIVUP(last_dim, 4) * 4;
     ret_colwise.shape = {scale_dim_0, scale_dim_1};
   }
   ret_rowwise.type = DType::kFloat32;
@@ -254,14 +308,15 @@ Tensor::Tensor(const std::string& name,
   NVTEShape columnwise_shape = {};
   std::vector<size_t> columnwise_shape_vec;
-  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
+  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
+      || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
     // Transpose when tensor scaling
     columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
     for (size_t i = 0; i < shape.ndim - 1; ++i) {
       columnwise_shape_vec.emplace_back(shape.data[i]);
     }
   } else {
-    // Same shape for MX
+    // Same shape for MX and NVFP4
     for (size_t i = 0; i < shape.ndim; ++i) {
       columnwise_shape_vec.emplace_back(shape.data[i]);
     }
@@ -287,10 +342,13 @@ Tensor::Tensor(const std::string& name,
       std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
     }
   }
-  tensor_.set_rowwise_data(dptr_rowwise, type, shape);
-  tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
-  if (isFp8Type(type)) {
+  const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
+  const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
+  tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
+  tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
+  if (isFp8Type(type) || isFp4Type(type)) {
     if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
       cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
       cudaMemset(amax, 0, sizeof(float));
@@ -309,13 +367,19 @@ Tensor::Tensor(const std::string& name,
       }
       if (columnwise) {
         tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                          std::vector<size_t>{1});
         columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
         std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
       }
     } else {
-      auto [rowwise_scale_meta, colwise_scale_meta] =
-          get_scales(normalized_shape, tensor_.scaling_mode());
+      if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
+        // Used for NVFP4 second stage scaling
+        cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
+        cudaMemset(scale, 0, sizeof(float));
+        scale_cpu_data_ = std::make_shared<float>(0);
+        tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
+      }
+      auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
       auto rowwise_scale_size = rowwise_scale_meta.bytes();
       auto columnwise_scale_size = colwise_scale_meta.bytes();
       auto scale_shape = rowwise_scale_meta.shape;
@@ -350,13 +414,16 @@ void Tensor::to_cpu() const {
                cudaMemcpyDeviceToHost);
   }
   if (columnwise_) {
+    const DType colwise_type = tensor_.dtype();
+    const size_t colwise_size = bytes(s, colwise_type);
     cudaMemcpy(cpu_data_columnwise_.get(),
                tensor_.get_columnwise_data().data_ptr,
-               size,
+               colwise_size,
                cudaMemcpyDeviceToHost);
   }
-  if (isFp8Type(dtype())) {
-    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
+    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
       if (tensor_.amax() != nullptr){
         cudaMemcpy(amax_cpu_data_.get(),
                    tensor_.amax(),
@@ -368,8 +435,7 @@ void Tensor::to_cpu() const {
                  sizeof(float),
                  cudaMemcpyDeviceToHost);
     }
-    auto [rowwise_scale_meta, colwise_scale_meta] =
-        get_scales(s, tensor_.scaling_mode());
+    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
     if (rowwise_) {
       auto scale_size = rowwise_scale_meta.bytes();
       cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
@@ -398,15 +464,15 @@ void Tensor::from_cpu() const {
     cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
                cudaMemcpyHostToDevice);
   }
-  if (isFp8Type(dtype())) {
-    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
+    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
+        || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
       if (tensor_.amax() != nullptr){
         cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
       }
       cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
     }
-    auto [rowwise_scale_meta, colwise_scale_meta] =
-        get_scales(s, tensor_.scaling_mode());
+    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
     if (rowwise_) {
       auto scale_size = rowwise_scale_meta.bytes();
       cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
@@ -423,7 +489,7 @@ void Tensor::from_cpu() const {
 }

 void Tensor::set_scale(float scale) {
-  if (isFp8Type(dtype())) {
+  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
     NVTE_CHECK(scale_cpu_data_);
     if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
       *scale_cpu_data_ = scale;
@@ -433,7 +499,7 @@ void Tensor::set_scale(float scale) {
 }

 void Tensor::set_scale_inv(float scale_inv) {
-  if (isFp8Type(dtype())) {
+  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
     if (rowwise_) {
       NVTE_CHECK(rowwise_scale_inv_cpu_data_);
     }
@@ -441,8 +507,7 @@ void Tensor::set_scale_inv(float scale_inv) {
     NVTE_CHECK(columnwise_scale_inv_cpu_data_);
   }
-  auto [rowwise_scale_meta, colwise_scale_meta] =
-      get_scales(tensor_.shape(), tensor_.scaling_mode());
+  auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
   if (rowwise_) {
     auto num_scales = product(rowwise_scale_meta.shape);
     if (num_scales == 1) {
@@ -472,7 +537,8 @@ void Tensor::set_scale_inv(float scale_inv) {
 }

 void Tensor::shareFP8Meta(const Tensor &other) {
-  if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
+  if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
+      || isFp4Type(dtype()) && isFp4Type(other.dtype())) {
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
@@ -724,12 +790,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
   }
 }

-void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                  const size_t row_blocks, const size_t col_blocks, const size_t stride,
-                                  size_t& mismatches_num, const size_t atol,
-                                  const double abs_tolerable_mismatches_limit,
-                                  const double rel_tolerable_mismatches_limit)
+template <typename T>
+struct CastToType;
+
+template <>
+struct CastToType<uint8_t> {
+  using type = int;
+};
+
+template <>
+struct CastToType<fp8e4m3> {
+  using type = float;
+};
+
+template <typename T>
+void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
+                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                             size_t& mismatches_num, const size_t atol,
+                             const double abs_tolerable_mismatches_limit,
+                             const double rel_tolerable_mismatches_limit)
 {
+  using UpcastType = typename CastToType<T>::type;
+  auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);
   const size_t N = row_blocks * col_blocks;
   const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
                                                      std::floor(N * rel_tolerable_mismatches_limit));
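The mismatch budget combines an absolute cap with a relative one over all compared scales; whichever is smaller wins. A small sketch of that rule (illustrative Python, not the C++ helper):

# Sketch: tolerated number of mismatching scaling factors out of n_scales comparisons.
import math

def tolerable_mismatches_limit(n_scales: int, abs_limit: float, rel_limit: float) -> int:
    """Allow at most abs_limit mismatches, but never more than rel_limit of all scales."""
    return int(min(abs_limit, math.floor(n_scales * rel_limit)))

# With rel_limit = 1e-4, a million scales tolerate up to 100 mismatches,
# unless the absolute limit is tighter.
assert tolerable_mismatches_limit(1_000_000, abs_limit=16, rel_limit=1.0e-4) == 16
assert tolerable_mismatches_limit(1_000_000, abs_limit=1000, rel_limit=1.0e-4) == 100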
@@ -739,11 +823,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
   for (int i = 0; i < row_blocks; ++i) {
     for (int j = 0; j < col_blocks; ++j) {
       const int idx = i * stride + j;
-      const int test_val = static_cast<int>(test[idx]);
-      const int ref_val = static_cast<int>(ref[idx]);
-      const int abs_delta = std::abs(test_val - ref_val);
-      if (abs_delta > atol) {
+      float t, r;
+      bool assertion = false;
+      if (std::is_same<T, uint8_t>::value) {
+        t = static_cast<float>(test[idx]);
+        r = static_cast<float>(ref[idx]);
+        assertion = std::abs(t - r) > atol;
+      } else {
+        t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
+        r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
+        const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
+                              && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
+        if (mismatch) {
+          /* Check if it is just a failure of round to nearest choosing different
+             side of the real value */
+          const double mean = (t + r) / 2;
+          const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
+          const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
+          const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
+          const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
+          assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
+        }
+      }
+      if (assertion) {
         mismatches_num++;
         mismatch_indices.push_back(idx);
       }
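The fp8e4m3 branch tolerates the case where the GPU and the reference merely rounded the same real value to the two adjacent representable fp8 numbers. A rough sketch of that idea in Python, using PyTorch's torch.float8_e4m3fn as a stand-in cast (an assumption about the environment, not the test's fp8e4m3 type):

# Sketch: is a (test, ref) pair just a round-to-nearest tie broken to different sides?
import torch

def is_round_to_nearest_ambiguity(t: float, r: float) -> bool:
    """True if t and r look like the two nearest fp8e4m3 neighbours of one real value."""
    mean = (t + r) / 2
    eps = 1e-6
    # Nudge the midpoint slightly to each side before casting down and back up.
    mean_p = mean * (1 + eps) if mean >= 0 else mean * (1 - eps)
    mean_m = mean * (1 - eps) if mean >= 0 else mean * (1 + eps)
    cast = lambda x: torch.tensor(x).to(torch.float8_e4m3fn).float().item()
    return cast(mean_m) == min(t, r) and cast(mean_p) == max(t, r)

Only pairs for which this check fails are counted against the mismatch budget.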
@@ -751,8 +855,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
     std::cout << "Error in " << name << std::endl;
     for (const int index : mismatch_indices) {
       std::cout << "Mismatch at (" << index << "):"
-                << static_cast<int>(test[index]) << " vs "
-                << static_cast<int>(ref[index]) << std::endl;
+                << static_cast<UpcastType>(test[index]) << " vs "
+                << static_cast<UpcastType>(ref[index]) << std::endl;
     }
     GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                  << tolerable_mismatches_limit << ".";
@@ -761,6 +865,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
   }
 }

+// Instantiate templates
+template
+void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
+                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                                      size_t& mismatches_num, const size_t atol,
+                                      const double abs_tolerable_mismatches_limit,
+                                      const double rel_tolerable_mismatches_limit);
+template
+void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
+                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                                      size_t& mismatches_num, const size_t atol,
+                                      const double abs_tolerable_mismatches_limit,
+                                      const double rel_tolerable_mismatches_limit);
+
 std::pair<double, double> getTolerances(const DType type) {
   switch(type) {
     case DType::kFloat32:
@@ -920,6 +1040,10 @@ bool isFp8Type(DType type) {
   return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
 }

+bool isFp4Type(DType type) {
+  return type == DType::kFloat4E2M1;
+}
+
 int32_t getDeviceComputeCapability() {
   cudaDeviceProp deviceProp;
   cudaGetDeviceProperties(&deviceProp, 0);
@@ -941,7 +1065,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                             const size_t cols,
                                             const size_t block_size_rows,
                                             const size_t block_size_cols) {
-  const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32);
+  const bool is_rowwise = (block_size_rows == 1)
+                          && ((block_size_cols == 32) || (block_size_cols == 16));
   const size_t alignment_Y = is_rowwise
                              ? scale_tensor_alignment_Y_rowwise
...
@@ -79,6 +79,8 @@ using fp8e8m0 = uint8_t;
 using int8 = int8_t;

 #if FP4_TYPE_SUPPORTED
 using fp4e2m1 = __nv_fp4_e2m1;
+using fp4e2m1x2 = __nv_fp4x2_e2m1;
+using fp4e2m1x4 = __nv_fp4x4_e2m1;
 #endif

 template <typename T>
@@ -240,7 +242,9 @@ class Tensor {
   float scale() const {
     if(scale_cpu_data_) {
-      NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!");
+      NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
+                 || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING),
+                 "Invalid scaling_mode!");
       to_cpu();
       return *scale_cpu_data_;
     } else {
@@ -254,6 +258,8 @@ class Tensor {
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
     } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
+    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
+      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
     } else {
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
     }
@@ -267,6 +273,8 @@ class Tensor {
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
     } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
+    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
+      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
     } else {
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
     }
@@ -321,10 +329,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
 constexpr uint32_t FP32_MANTISSA_BITS = 23;

 // [128,4] rowwise and [4,128] colwise alignment requirement
-constexpr size_t scale_tensor_alignment_X_rowwise = 4;
 constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
-constexpr size_t scale_tensor_alignment_X_colwise = 128;
+constexpr size_t scale_tensor_alignment_X_rowwise = 4;
 constexpr size_t scale_tensor_alignment_Y_colwise = 4;
+constexpr size_t scale_tensor_alignment_X_colwise = 128;

 inline size_t divide_round_up(const size_t N, const size_t M) {
   return (N - 1 + M) / M;
@@ -473,12 +481,14 @@ void compareResults(const std::string &name, const float test, const float ref,
                     double atol = 1e-5, double rtol = 1e-8);
 void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                     size_t N, float mismatch_rate_tol = 0.);
-void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                  const size_t row_blocks, const size_t col_blocks, const size_t stride,
-                                  size_t& mismatches_num,
-                                  const size_t scale_diff_abs_tolerance = 0,
-                                  const double abs_tolerable_mismatches_limit = 0,
-                                  const double rel_tolerable_mismatches_limit = 0);
+template <typename T>
+void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
+                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                             size_t& mismatches_num,
+                             const size_t scale_diff_abs_tolerance = 0,
+                             const double abs_tolerable_mismatches_limit = 0,
+                             const double rel_tolerable_mismatches_limit = 0);

 std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
                                             const size_t block_size_rows, const size_t block_size_cols);
@@ -501,6 +511,7 @@ const std::string& caseName(InputsFillCase type);
 extern std::vector<DType> all_fp_types;

 bool isFp8Type(DType type);
+bool isFp4Type(DType type);

 int32_t getDeviceComputeCapability();
 constexpr int32_t hopperComputeCapability = 90;
@@ -578,7 +589,7 @@ constexpr int32_t blackwellComputeCapability = 100;
     SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
     default: \
       printf("dtype: %d\n", static_cast<int>(dtype)); \
-      NVTE_ERROR("Invalid type MARKED TEST."); \
+      NVTE_ERROR("Invalid type."); \
   }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
@@ -597,7 +608,7 @@ constexpr int32_t blackwellComputeCapability = 100;
       } \
       break; \
     default: \
-      NVTE_ERROR("Invalid type MARKED TEST 2."); \
+      NVTE_ERROR("Invalid type."); \
   }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
@@ -605,7 +616,7 @@ constexpr int32_t blackwellComputeCapability = 100;
     using namespace transformer_engine; \
     SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
     default: \
-      NVTE_ERROR("Invalid type MARKED TEST 3."); \
+      NVTE_ERROR("Invalid type."); \
   }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
@@ -630,5 +641,5 @@ constexpr int32_t blackwellComputeCapability = 100;
       } \
       break; \
     default: \
-      NVTE_ERROR("Invalid type MARKED TEST 4."); \
+      NVTE_ERROR("Invalid type."); \
   }
@@ -780,9 +780,15 @@ class TestFusedQuantize:
         assert_allclose(te_output.data, jax_output.data)

         if is_dbias:
-            # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
             precise_comparison = not (
-                in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
+                # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
+                (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
+                # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
+                or (
+                    activation_type == ("squared_relu",)
+                    and in_dtype == jnp.bfloat16
+                    and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
+                )
             )
             assert_allclose(
                 te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
...
@@ -76,8 +76,6 @@ class TestDistributedLayernorm:
             all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
         )
         other_bytes = 0
-        if fp8_recipe == recipe.Float8CurrentScaling():
-            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction

         return generate_collectives_count(
             allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
         )
...
@@ -6,6 +6,7 @@ import os
 import subprocess
 import sys
 import pathlib
+import logging

 import pytest
 import torch
@@ -13,20 +14,28 @@ from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     get_cudnn_version,
 )
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8CurrentScaling,
+)
 from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils

 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
 from utils import ModelConfig, get_available_attention_backends

+pytest_logging_level = logging.getLevelName(logging.root.level)
+
 # Initialize RNG state
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)

+test_essential = True
+
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 model_configs_flash_attn = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
     "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
     "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
@@ -61,18 +70,31 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
     return args

+dtypes = ["bf16", "fp16"]
+qkv_formats = ["bshd", "sbhd", "thd"]
+cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
+if test_essential:
+    configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
+    model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
+    dtypes = ["bf16"]
+    qkv_formats = ["sbhd", "thd"]
+
 @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
 @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
-@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
+@pytest.mark.parametrize("qkv_format", qkv_formats)
+@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
 def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
     if num_gpus > torch.cuda.device_count():
         pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
     config = model_configs_flash_attn[model]
+    config.context_parallel = True
+    config.cp_comm_type = cp_comm_type
     if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and qkv_format == "thd":
@@ -90,6 +112,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
         )
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
+    dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
+    available_backends, *_ = get_available_attention_backends(
+        config,
+        qkv_dtype=dtypes[dtype],
+        qkv_layout="_".join([qkv_format] * 3),
+    )
+    flash_attn_supported, *_ = available_backends
+    if not flash_attn_supported:
+        pytest.skip("No attention backend available.")

     subprocess.run(
         get_bash_arguments(
@@ -99,13 +130,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
             qkv_format=qkv_format,
             kernel_backend="FlashAttention",
             cp_comm_type=cp_comm_type,
+            log_level=pytest_logging_level,
         ),
         check=True,
     )

 model_configs_fused_attn = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
     "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
     "cp_1_2": ModelConfig(
@@ -136,17 +168,42 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
     ),  # MLA
     "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64),  # MLA
+    "cp_4_0": ModelConfig(
+        2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
+    ),  # GQA
+    "cp_4_1": ModelConfig(
+        2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
+    ),  # GQA
+    "cp_4_2": ModelConfig(
+        2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
+    ),  # GQA
 }

+dtypes = ["bf16", "fp16", "fp8"]
+qkv_formats = ["bshd", "sbhd", "thd"]
+cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
+if test_essential:
+    configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
+    dtypes = ["bf16", "fp8"]
+    qkv_formats = ["sbhd", "thd"]
+
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
 @pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
-@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
-@pytest.mark.parametrize("fp8_mha", [False, True])
-def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
+@pytest.mark.parametrize("qkv_format", qkv_formats)
+@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
+@pytest.mark.parametrize("fp8_bwd", [True, False])
+@pytest.mark.parametrize("fp8_mha", [True, False])
+@pytest.mark.parametrize("fp8_dpa", [True, False])
+@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"])
+@pytest.mark.parametrize("f16_O", [True, False])
+def test_cp_with_fused_attention(
+    dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O
+):
     num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
     if num_gpus > torch.cuda.device_count():
         pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
@@ -157,8 +214,15 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
     if dtype == "fp8" and get_device_compute_capability() < (9, 0):
         pytest.skip("FP8 attention is only supported on sm90+!")
+    if dtype == "fp8" and not fp8_dpa and fp8_mha:
+        pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!")
+    if dtype != "fp8" and fp8_bwd:
+        pytest.skip("Only fp8 works with fp8_bwd=True!")
     config = model_configs_fused_attn[model]
+    config.context_parallel = True
+    config.cp_comm_type = cp_comm_type
     if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
         pytest.skip("THD format does not support post_scale_bias yet!")
     if qkv_format == "thd" and cp_comm_type == "all_gather":
@@ -186,19 +250,57 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
             f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
             f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
         )
-    if dtype != "fp8" and fp8_mha:
-        pytest.skip("Only fp8 works with fp8_mha=True!")
+    if dtype != "fp8" and (fp8_mha or fp8_dpa):
+        pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!")
+    if dtype == "fp8" and not (fp8_mha or fp8_dpa):
+        pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!")
+    if dtype != "fp8" and scaling_mode is not None:
+        pytest.skip("Only fp8 works with scaling_mode != None!")
+    if dtype == "fp8" and scaling_mode is None:
+        pytest.skip("fp8 only works with scaling_mode != None!")
+    if (
+        dtype == "fp8"
+        and scaling_mode == "current"
+        and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"]
+    ):
+        pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!")
+    if f16_O and (dtype != "fp8" or scaling_mode != "current"):
+        pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!")
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
     if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently does not support FP8 attention!")
+    if dtype == "fp8" and config.softmax_type != "vanilla":
+        pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!")
+    if config.softmax_type != "vanilla" and cp_comm_type != "a2a":
+        pytest.skip(
+            "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
+        )
+    if config.softmax_type != "vanilla" and qkv_format == "thd":
+        pytest.skip(
+            "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+        )
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+    fp8_meta = {}
+    fp8_meta["recipe"] = None
+    fp8_meta["local_recipes"] = []
+    fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha)
+    if fp8 and scaling_mode == "delayed":
+        fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
+        fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)]
+    if fp8 and scaling_mode == "current":
+        fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
+        fp8_meta["local_recipes"] = [
+            Float8CurrentScaling(fp8_dpa=True),
+            DelayedScaling(fp8_dpa=True),
+        ]
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
-        qkv_dtype=dtypes[dtype],
+        qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
         qkv_layout="_".join([qkv_format] * 3),
-        window_size=config.window_size,
-        context_parallel=True,
+        fp8=fp8,
+        fp8_meta=fp8_meta,
     )
     _, fused_attn_supported, _ = available_backends
     if not fused_attn_supported:
@@ -212,7 +314,12 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
             qkv_format=qkv_format,
             kernel_backend="FusedAttention",
             cp_comm_type=cp_comm_type,
+            fp8_bwd=fp8_bwd,
+            fp8_dpa=fp8_dpa,
             fp8_mha=fp8_mha,
+            scaling_mode=scaling_mode,
+            f16_O=f16_O,
+            log_level=pytest_logging_level,
         ),
         check=True,
     )
@@ -469,7 +469,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
-        window_size=config.window_size,
         pad_between_seqs=False,
         is_training=False,
         fp8=is_fp8,
...
@@ -9,6 +9,7 @@ import datetime
 import os
 import sys
 from functools import wraps
+import math

 import torch
 from torch import nn
@@ -20,10 +21,15 @@ from transformer_engine.common.recipe import (
     DelayedScaling,
     Float8CurrentScaling,
     Float8BlockScaling,
+    NVFP4BlockScaling,
     Format,
     Recipe,
+    QParams,
 )
 from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
+from transformer_engine.pytorch.distributed import gather_along_first_dim
 from run_layer_with_overlap import _compare_tensors
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
@@ -48,6 +54,14 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
     )

+def nvfp4_vanilla():
+    nvfp4_recipe = NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = QParams()
+    return nvfp4_recipe
+
 # Quantization recipe setup
 def quantization_recipe() -> Recipe:
     if QUANTIZATION == "fp8":
@@ -60,6 +74,8 @@ def quantization_recipe() -> Recipe:
         return Float8CurrentScaling()
     if QUANTIZATION == "fp8_block_scaling":
         return Float8BlockScaling()
+    if QUANTIZATION == "nvfp4":
+        return nvfp4_vanilla()
     return te.fp8.get_default_fp8_recipe()
@@ -97,10 +113,14 @@ def main(argv=None, namespace=None):
     # Quantization scheme
     QUANTIZATION = args.quantization
     global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
-    if QUANTIZATION in ("fp8", "mxfp8"):
+    if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"):
         SEQ_LEN = 32
         BATCH_SIZE = 32
         HIDDEN_SIZE = 128
+    # For fp8 block scaling the block size is 128, so for low-precision TP to work the
+    # input tensor must be divisible into 128x128 blocks; only then is it eligible for
+    # the low-precision All-Gather when needed.
     elif QUANTIZATION == "fp8_block_scaling":
         SEQ_LEN = 128
         BATCH_SIZE = 128
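A quick sketch of the shape constraint described in the comment above (illustrative Python with hypothetical sizes; the actual HIDDEN_SIZE used for block scaling is not shown in this excerpt):

# Sketch: the flattened input [SEQ_LEN * BATCH_SIZE, HIDDEN_SIZE] must tile evenly
# into 128x128 blocks for the low-precision all-gather path to apply.
def eligible_for_low_precision_all_gather(seq_len, batch_size, hidden_size, block=128):
    rows = seq_len * batch_size
    return rows % block == 0 and hidden_size % block == 0

assert eligible_for_low_precision_all_gather(128, 128, 512)       # hypothetical sizes
assert not eligible_for_low_precision_all_gather(32, 32, 100)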
...@@ -108,6 +128,7 @@ def main(argv=None, namespace=None): ...@@ -108,6 +128,7 @@ def main(argv=None, namespace=None):
test_dict = [ test_dict = [
test_quantizer, test_quantizer,
test_quantized_all_gather,
test_linear, test_linear,
test_layernorm, test_layernorm,
test_layernorm_linear, test_layernorm_linear,
...@@ -177,6 +198,9 @@ def _get_tolerances(dtype): ...@@ -177,6 +198,9 @@ def _get_tolerances(dtype):
# row parallel & sequence parallel, because we do the all_gather in backward pass # row parallel & sequence parallel, because we do the all_gather in backward pass
if QUANTIZATION == "fp8_cs": if QUANTIZATION == "fp8_cs":
return {"rtol": 0.4, "atol": 0.25} return {"rtol": 0.4, "atol": 0.25}
elif QUANTIZATION == "nvfp4":
# TODO(zhongboz): investigate why the tolerance is so large
return {"rtol": 0.125, "atol": 0.12}
elif QUANTIZATION is not None: elif QUANTIZATION is not None:
return {"rtol": 0.125, "atol": 0.0625} return {"rtol": 0.125, "atol": 0.0625}
...@@ -327,24 +351,36 @@ def _alloc_main_grad(model_single_node, model_distributed): ...@@ -327,24 +351,36 @@ def _alloc_main_grad(model_single_node, model_distributed):
############################################### ###############################################
# Quantizer # # Quantizer #
############################################### ###############################################
def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size):
""" """
quantizer is the reference quantizer on a single GPU. quantizer is the reference quantizer on a single GPU.
quantizer_dist is the distributed quantizer to be tested on multiple GPUs. quantizer_dist is the distributed quantizer to be tested on multiple GPUs.
""" """
if quantizer_class == Float8CurrentScalingQuantizer: if quantizer_class == Float8CurrentScalingQuantizer:
quantizer_dist = quantizer_class( quantizer_dist = quantizer_class(
fp8_dtype=fp8_dtype, fp8_dtype=low_precision_dtype,
device=device, device=device,
with_amax_reduction=True, with_amax_reduction=True,
amax_reduction_group=tp_group, amax_reduction_group=tp_group,
) )
quantizer = quantizer_class( quantizer = quantizer_class(
fp8_dtype=fp8_dtype, fp8_dtype=low_precision_dtype,
device=device, device=device,
with_amax_reduction=False, with_amax_reduction=False,
) )
return quantizer, quantizer_dist return quantizer, quantizer_dist
elif quantizer_class == NVFP4Quantizer:
quantizer_dist = quantizer_class(
fp4_dtype=low_precision_dtype,
with_amax_reduction=True,
amax_reduction_group=tp_group,
)
quantizer = quantizer_class(
fp4_dtype=low_precision_dtype,
with_amax_reduction=False,
amax_reduction_group=None,
)
return quantizer, quantizer_dist
else: else:
raise ValueError(f"Unsupported quantizer class: {quantizer_class}") raise ValueError(f"Unsupported quantizer class: {quantizer_class}")
...@@ -415,6 +451,194 @@ def test_quantizer(): ...@@ -415,6 +451,194 @@ def test_quantizer():
_test_quantizer(input_dtype, fp8_dtype) _test_quantizer(input_dtype, fp8_dtype)
############################################
# Quantized All-Gather #
############################################
def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape):
"""
Zero-pad the scale_inv tensor.
scale_inv already has the padded shape, but its padding region is not zero-filled;
unpadded_shape is the original shape before padding.
"""
dim0, dim1 = scale_inv.shape
unpadded_dim0, unpadded_dim1 = unpadded_shape
pad_dim0 = (128 - unpadded_dim0 % 128) % 128
pad_dim1 = (4 - unpadded_dim1 % 4) % 4
new_dim0 = unpadded_dim0 + pad_dim0
new_dim1 = unpadded_dim1 + pad_dim1
assert dim0 == new_dim0
assert dim1 == new_dim1
# return input if no padding is needed
if pad_dim0 == 0 and pad_dim1 == 0:
return scale_inv
# slice off the padding first to drop uninitialized values left over from torch.empty
scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous()
# using torch padding
new_scale_inv = torch.nn.functional.pad(
scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0
)
assert new_scale_inv.shape == (new_dim0, new_dim1)
return new_scale_inv
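# Editorial sketch (illustration only, not part of the change): given the 128x4 padding
# granularity above, an unpadded (100, 6) scale_inv is zero-padded to (128, 8), since
# pad_dim0 = (128 - 100 % 128) % 128 = 28 and pad_dim1 = (4 - 6 % 4) % 4 = 2; only the
# top-left 100x6 region carries real scales.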
def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise):
"""
Calculate the unpadded shape of the scale_inv tensor.
"""
M, K = 1, 1
M = math.prod(input_shape[:-1])
K = input_shape[-1]
if quantizer_cls == NVFP4Quantizer:
if columnwise:
outer = K
inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE)
return (outer, inner)
else:
outer = M
inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE)
return (outer, inner)
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_cls}")
@run_distributed_test()
def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):
"""Test the quantizer under distributed settings.
Args:
input_dtype (torch.dtype): The data type of the input.
low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8.
"""
M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
# high precision input
x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype)
# optionally set one element of the input (one that does not live on rank 0 after the split)
# to a very large value to exercise the amax reduction on purpose:
# x_hp_cpu[M - 1, N - 1] = 1e4
# get the unpadded shapes
unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False)
unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True)
# rank 0 takes the full copy and quantizes it on GPU 0 for verification
if WORLD_RANK == 0:
x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda")
x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK]
# Create quantizers
quantizer, quantizer_dist = _construct_quantizer(
quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE
)
# quantize the entire input
if WORLD_RANK == 0:
x_low_precision_single = quantizer(x_hp_rank0)
# run all-gather with a quantizer as input for quantized all-gather
x_low_precision_total, _ = gather_along_first_dim(
x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist
)
# check the outputs
if WORLD_RANK == 0:
# assert all data and scale_inv are the same
torch.testing.assert_close(
x_low_precision_single._rowwise_data,
x_low_precision_total._rowwise_data,
rtol=0.0,
atol=0.0,
)
# check the rowwise scale without any padding
unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape
unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
torch.testing.assert_close(
unpadded_rowwise_scale_inv_ref,
unpadded_rowwise_scale_inv,
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
_ref_zero_padding_scale_inv(
x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
),
_ref_zero_padding_scale_inv(
x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
),
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
x_low_precision_single._columnwise_data,
x_low_precision_total._columnwise_data,
rtol=0.0,
atol=0.0,
)
unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape
unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
torch.testing.assert_close(
unpadded_columnwise_scale_inv_ref,
unpadded_columnwise_scale_inv,
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
_ref_zero_padding_scale_inv(
x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
),
_ref_zero_padding_scale_inv(
x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
),
rtol=0.0,
atol=0.0,
)
def test_quantized_all_gather():
"""
Run quantized all-gather tests with various configurations.
"""
# skip this test for other quantization schemes
is_nvfp4 = QUANTIZATION == "nvfp4"
# add other recipes for testing if needed
if not is_nvfp4:
return
input_dtypes = [torch.bfloat16]
fp4_dtype = [tex.DType.kFloat4E2M1]
fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
quantizer_cls_nvfp4 = [NVFP4Quantizer]
# add FP8 quantizers if needed
quantizer_cls_fp8 = []
low_precision_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype
quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8
for quantizer_cls in quantizer_cls_list:
for input_dtype in input_dtypes:
for low_precision_dtype in low_precision_dtypes:
_test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls)
############################################ ############################################
# Linear # # Linear #
############################################ ############################################
...@@ -515,10 +739,11 @@ def test_linear(): ...@@ -515,10 +739,11 @@ def test_linear():
{"init_method": _constant}, {"init_method": _constant},
{"fuse_wgrad_accumulation": True}, {"fuse_wgrad_accumulation": True},
{"return_bias": True}, {"return_bias": True},
{"params_dtype": torch.float16}, {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"delay_wgrad_compute": True}, {"delay_wgrad_compute": True},
{"save_original_input": True}, {"save_original_input": True},
] ]
for kwargs in kwargs_list: for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue continue
...@@ -694,11 +919,12 @@ def test_layernorm_linear(): ...@@ -694,11 +919,12 @@ def test_layernorm_linear():
{"init_method": _constant}, {"init_method": _constant},
{"fuse_wgrad_accumulation": True}, {"fuse_wgrad_accumulation": True},
{"return_bias": True}, {"return_bias": True},
{"params_dtype": torch.float16}, {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"zero_centered_gamma": False}, {"zero_centered_gamma": False},
{"return_layernorm_output": True}, {"return_layernorm_output": True},
{"delay_wgrad_compute": True}, {"delay_wgrad_compute": True},
] ]
for kwargs in kwargs_list: for kwargs in kwargs_list:
for parallel_mode in ["column"]: for parallel_mode in ["column"]:
for sequence_parallel in [False, True]: for sequence_parallel in [False, True]:
...@@ -800,7 +1026,7 @@ def test_layernorm_mlp(): ...@@ -800,7 +1026,7 @@ def test_layernorm_mlp():
{"normalization": "RMSNorm"}, {"normalization": "RMSNorm"},
{"zero_centered_gamma": True}, {"zero_centered_gamma": True},
{"bias": False}, {"bias": False},
{"params_dtype": torch.float16}, {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"activation": "relu"}, {"activation": "relu"},
{"fuse_wgrad_accumulation": True}, {"fuse_wgrad_accumulation": True},
{"return_bias": True}, {"return_bias": True},
...@@ -898,7 +1124,7 @@ def test_transformer_layer(): ...@@ -898,7 +1124,7 @@ def test_transformer_layer():
{"fuse_qkv_params": True, "fuse_wgrad_accumulation": True}, {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True},
{"qkv_weight_interleaved": False}, {"qkv_weight_interleaved": False},
{"bias": False}, {"bias": False},
{"params_dtype": torch.float16}, {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"fuse_qkv_params": True}, {"fuse_qkv_params": True},
{"activation": "relu"}, {"activation": "relu"},
] ]
......
...@@ -31,6 +31,7 @@ mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available( ...@@ -31,6 +31,7 @@ mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available(
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available() FP8GlobalStateManager.is_fp8_block_scaling_available()
) )
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count()) NUM_PROCS: int = min(4, torch.cuda.device_count())
...@@ -51,7 +52,9 @@ def _run_test(quantization): ...@@ -51,7 +52,9 @@ def _run_test(quantization):
all_boolean = [True, False] all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) @pytest.mark.parametrize(
"quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_distributed(quantization): def test_distributed(quantization):
if quantization == "fp8" and not fp8_available: if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
...@@ -61,4 +64,6 @@ def test_distributed(quantization): ...@@ -61,4 +64,6 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling) pytest.skip(reason_for_no_fp8_block_scaling)
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
_run_test(quantization) _run_test(quantization)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
"""
Distributed numerics tests
These numerical tests aim for zero tolerance to give absolute confidence in numerics.
In the case of NVFP4, the experimental NVFP4 quantization matched the native silicon
bitwise. For distributed test cases we can do the same by comparing BF16 all-gather (AG)
results with the low precision AG results at the layer level.
"""
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def _run_test(quantization, batch_size, hidden_size, out_size):
test_path = TEST_ROOT / "run_numerics_exact.py"
test_cmd = LAUNCH_CMD + [str(test_path)]
test_cmd += ["--quantization", quantization]
test_cmd += ["--batch-size", str(batch_size)]
test_cmd += ["--hidden-size", str(hidden_size)]
test_cmd += ["--out-size", str(out_size)]
result = subprocess.run(test_cmd, env=os.environ, check=False)
assert result.returncode == 0
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", ["nvfp4"])
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(64, 128, 128),
(128, 128, 128),
(128, 256, 256),
(512, 1024, 768),
(512, 256, 1024),
(2048, 2048, 2048),
],
)
def test_distributed(quantization, batch_size, hidden_size, out_size):
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
_run_test(quantization, batch_size, hidden_size, out_size)
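# Editorial note (illustration only): the command assembled by _run_test above is equivalent to
# launching the script directly with the local process count, e.g.:
#   torchrun --nproc_per_node=4 run_numerics_exact.py --quantization nvfp4 --batch-size 64 --hidden-size 128 --out-size 128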