Unverified Commit cb504cda authored by Oleg Goncharov, committed by GitHub

[Common] Improved performance of mxfp8 cast kernels (#1628)



* Fixed conflicts
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor code refactoring to avoid unnecessary checks
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed typo
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed dBias accumulation error due to initialization. Minor code refactoring
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Test case to reproduce the init error
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed rowwise dbias error
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed ptx API
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added a struct for two packed FP8 values
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Rolled back to scalar code for columnwise scaling due to its better performance
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Minor corrections
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Rebased on main
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixes per code review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed constexpr in C++ test suite to build faster
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Computed activations are now numerically truncated to InputType before scaling. Improved test suite.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor refactoring
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Minor refactoring
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Modified mismatch checks of MXFP8 to address FP8 numerics
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Implemented Jeremy's fixes to JAX test suite with an intermediate downcast
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Reduced the dims of the test tensors to improve CI runtime
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed memory alignment issue. Compute dbias without downcast.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed the misaligned memory issue in the gated kernels as well. Reduced the size of the MXFP8 gated tests
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 315b47db
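
For context on the math these kernels implement: MXFP8 quantizes every block of 32 consecutive values with a shared power-of-two scale stored as an e8m0 biased exponent. Below is a minimal CPU sketch of that recipe, mirroring the float_to_e8m0 / exp2f_rcp helpers in the test diff that follows; the FP8 cast itself is stubbed with a plain float and the exponent rounding is an assumption for illustration, not the kernel's exact implementation.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Sketch of MXFP8 block scaling: 32 consecutive values share one e8m0 scale.
constexpr int BLOCK = 32;
constexpr float E4M3_MAX = 448.0f;  // largest normal e4m3 magnitude
constexpr int E8M0_BIAS = 127;

int main() {
    float block[BLOCK], quantized[BLOCK];
    for (int i = 0; i < BLOCK; ++i) block[i] = 0.01f * (i - 15);

    // 1) Absolute maximum of the block (the real code special-cases amax == 0).
    float amax = 0.0f;
    for (float v : block) amax = std::max(amax, std::fabs(v));

    // 2) Power-of-two scale chosen so that amax / scale fits into e4m3.
    const int unbiased_exp = static_cast<int>(std::ceil(std::log2(amax / E4M3_MAX)));
    const uint8_t biased_exponent =
        static_cast<uint8_t>(std::clamp(unbiased_exp + E8M0_BIAS, 0, 254));

    // 3) Multiply by the reciprocal scale; a real kernel would now cast to FP8.
    const float scale_rcp = std::exp2f(static_cast<float>(E8M0_BIAS - biased_exponent));
    for (int i = 0; i < BLOCK; ++i) quantized[i] = block[i] * scale_rcp;

    std::printf("e8m0 scale byte: %u, first scaled value: %f\n",
                biased_exponent, quantized[0]);
}
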
@@ -36,95 +36,34 @@ enum ActivationType {
SReLU
};
template <typename InputType, typename OutputType, float (*OP)(const float)>
void scale_block(const ProcessingMethod processing_method,
template <typename InputType, typename OutputType>
void compute_ref(const ProcessingMethod processing_method,
float (*OP)(const float),
const bool rowwise,
const bool colwise,
const InputType* input,
const InputType* grad,
OutputType* output_c,
float* dbias,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float amax = 0.0f;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
dbias[j] += elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
amax = std::max(amax, std::abs(elt));
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits<OutputType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
output_c[idx] = static_cast<OutputType>(elt * scale_reciprocal);
}
}
}
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x1(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
fp8e8m0* output_scales,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride)
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* output_scales_rowwise,
fp8e8m0* output_scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise)
{
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tile_size_Y = 32;
const size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer(tile_size_Y * tile_size_X);
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
@@ -133,24 +72,83 @@ void compute_ref_x1(const ProcessingMethod processing_method,
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, thread_dbias.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(i_min + tile_size_Y, rows);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(j_min + tile_size_X, cols);
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
thread_dbias[j] += elt;
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
elt = static_cast<float>(static_cast<InputType>(elt));
cache_buffer[cache_idx] = elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
}
}
if (rowwise) {
for (size_t i = i_min; i < i_max; ++i) {
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
if (colwise) {
for (size_t j = j_min; j < j_max; ++j) {
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
}
@@ -166,29 +164,6 @@ void compute_ref_x1(const ProcessingMethod processing_method,
}
}
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x2(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/**
* Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias):
@@ -197,8 +172,9 @@ void compute_ref_x2(const ProcessingMethod processing_method,
* 2) Scaled columns + column-wise scaling factors
*/
template <typename InputType, typename OutputType, float (*OP)(const float)>
template <typename InputType, typename OutputType>
void performTest_x1(const ProcessingMethod processing_method,
float (*OP)(const float),
const std::vector<size_t>& shape,
const bool rowwise,
const bool colwise,
@@ -261,28 +237,46 @@ void performTest_x1(const ProcessingMethod processing_method,
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }
nvte_quantize_dbias_dact(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
nvte_quantize_dbias_dact(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output_c.data(), 0);
auto nvte_dact = &nvte_dgelu;
if (OP == &dsilu) { nvte_dact = &nvte_dsilu; }
else if (OP == &drelu) { nvte_dact = &nvte_drelu; }
else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }
nvte_dact(grad.data(), input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output_c.data(), 0);
auto nvte_act = &nvte_gelu;
if (OP == &silu) { nvte_act = &nvte_silu; }
else if (OP == &relu) { nvte_act = &nvte_relu; }
else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
else if (OP == &srelu) { nvte_act = &nvte_srelu; }
nvte_act(input.data(), output_c.data(), 0);
break;
}
}
@@ -291,29 +285,45 @@ void performTest_x1(const ProcessingMethod processing_method,
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x1<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c.get(),
ref_output_scales.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
compute_ref<InputType, OutputType>(processing_method,
OP,
rowwise,
colwise,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c.get(),
ref_output_c.get(),
ref_output_scales.get(),
ref_output_scales.get(),
ref_output_dbias.get(),
rows,
cols,
scales_stride,
scales_stride);
const uint8_t * const gpu_scales_ptr = rowwise
? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales = 0;
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
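// Each e8m0 scale covers a 32-element block, so one mismatched scale
// can account for up to 32 mismatched output elements.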
const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
@@ -332,8 +342,9 @@ void performTest_x1(const ProcessingMethod processing_method,
* AND
* 2) Scaled columns + column-wise scaling factors
*/
template <typename InputType, typename OutputType, float (*OP)(const float)>
template <typename InputType, typename OutputType>
void performTest_x2(const ProcessingMethod processing_method,
float (*OP)(const float),
const std::vector<size_t>& shape,
const size_t block_size_rows,
const size_t block_size_cols,
@@ -401,28 +412,46 @@ void performTest_x2(const ProcessingMethod processing_method,
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }
nvte_quantize_dbias_dact(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
nvte_quantize_dbias_dact(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output.data(), 0);
auto nvte_dact = &nvte_dgelu;
if (OP == &dsilu) { nvte_dact = &nvte_dsilu; }
else if (OP == &drelu) { nvte_dact = &nvte_drelu; }
else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }
nvte_dact(grad.data(), input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output.data(), 0);
auto nvte_act = &nvte_gelu;
if (OP == &silu) { nvte_act = &nvte_silu; }
else if (OP == &relu) { nvte_act = &nvte_relu; }
else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
else if (OP == &srelu) { nvte_act = &nvte_srelu; }
nvte_act(input.data(), output.data(), 0);
break;
}
}
@@ -431,32 +460,54 @@ void performTest_x2(const ProcessingMethod processing_method,
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x2<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c_rowwise.get(),
ref_output_c_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
compute_ref<InputType, OutputType>(processing_method,
OP,
true,
true,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c_rowwise.get(),
ref_output_c_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_output_dbias.get(),
rows,
cols,
scales_stride_rowwise,
scales_stride_colwise);
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
@@ -475,11 +526,10 @@ std::vector<std::vector<size_t>> matrix_sizes = {
{128, 128},
{256, 256},
{993, 512},
{256, 65536},
{2048, 6144},
{16384, 128},
{32768, 160},
{4096, 1632},
{511, 6144},
{8192, 128},
{2048, 160},
{577, 1632},
{1024},
{8, 32, 1024},
{16, 8, 4, 512},
@@ -528,26 +578,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
transformer_engine::DType,
InputsFillCase>> {};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
}
TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
@@ -581,35 +611,48 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
const bool colwise = block_size.first != 1;
if (processing_method == ProcessingMethod::CAST_ACT) {
// Forward activations
ACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &gelu; break;
case ActivationType::SiLU: OP = &silu; break;
case ActivationType::ReLU: OP = &relu; break;
case ActivationType::QGeLU: OP = &qgelu; break;
case ActivationType::SReLU: OP = &srelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType>(
processing_method, OP, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType>(
processing_method, OP, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
} else {
DACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &dgelu; break;
case ActivationType::SiLU: OP = &dsilu; break;
case ActivationType::ReLU: OP = &drelu; break;
case ActivationType::QGeLU: OP = &dqgelu; break;
case ActivationType::SReLU: OP = &dsrelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType>(
processing_method, OP, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType>(
processing_method, OP, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
}
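
One behavioral change above: the reference now numerically truncates computed activations to InputType before the scaling step (the cache_buffer writes in compute_ref). Below is a minimal sketch of that round trip for BF16, emulated with bit operations so it runs without CUDA headers; round-to-nearest-even is assumed, matching typical hardware casts.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulate the reference's "downcast to InputType, upcast back to FP32"
// round trip for BF16: keep the top 16 bits with round-to-nearest-even.
// NaN/Inf handling is omitted in this sketch.
float truncate_to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
    bits = (bits + rounding_bias) & 0xFFFF0000u;
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

int main() {
    const float activation = 0.1234567f;
    std::printf("fp32: %.9f  bf16 round trip: %.9f\n",
                activation, truncate_to_bf16(activation));
}
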
......
@@ -18,107 +18,32 @@ using namespace test;
namespace {
template <bool IS_DGATED, typename IType, typename OType>
void scale_block(const IType* grad,
template <typename IType, typename OType>
void compute_ref(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t scale_idx_gate,
float& thread_amax,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float block_amax = 0.0f;
float block_amax_gate = 0.0f;
const size_t stride = cols * 2;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float gated_amax_act = 0;
float gated_amax_gate = 0;
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
gated_amax_act = abs(after_dsilu);
gated_amax_gate = abs(after_dgate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
gated_amax_act = abs(after_silu);
}
if (gated_amax_act > block_amax) { block_amax = gated_amax_act; }
if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; }
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
float scale_reciprocal_gate = 1;
if constexpr (IS_DGATED) {
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent);
output_scales[scale_idx_gate] = biased_exponent;
}
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal);
output[i * stride + cols + j] = static_cast<OType>(after_dgate *
scale_reciprocal_gate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
}
}
}
thread_amax = std::max(thread_amax, block_amax);
thread_amax = std::max(thread_amax, block_amax_gate);
}
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x1(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride) {
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
OType* output_rowwise,
OType* output_colwise,
fp8e8m0* output_scales_rowwise,
fp8e8m0* output_scales_colwise,
float& ref_amax,
const bool IS_DGATED,
const size_t rows,
const size_t cols,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise,
const bool is_rowwise,
const bool is_colwise) {
constexpr size_t tile_size_Y = 32;
constexpr size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
float amax = 0;
#pragma omp parallel reduction(max: amax) proc_bind(spread)
{
float thread_amax = 0;
// Buffers to cache intermediate computations
std::vector<float> cache_buffer_act(tile_size_Y * tile_size_X);
std::vector<float> cache_buffer_gate(tile_size_Y * tile_size_X);
float thread_amax = 0.0f;
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
@@ -126,26 +51,124 @@ void compute_ref_x1(const IType* grad,
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
cols / block_size_X;
scale_block<IS_DGATED, IType, OType>(
grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
thread_amax, i_min, i_max, j_min, j_max, cols);
const size_t stride = cols * 2;
const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(rows, tile_offset_Y + tile_size_Y);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(cols, tile_offset_X + tile_size_X);
// Compute and cache activations for the entire tile
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
if (IS_DGATED) {
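// dSiLU: d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)),
// which the next lines evaluate as dact_x = x * s * (1 - s) + s.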
const float x = silu_elt;
const float s = sigmoid(x);
const float act_x = x * s;
const float dact_x = x * s * (1 - s) + s;
const float grad_elt = static_cast<float>(grad[i * cols + j]);
float after_dsilu = dact_x * grad_elt * gate_elt;
float after_dgate = act_x * grad_elt;
// Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32
after_dsilu = static_cast<float>(static_cast<IType>(after_dsilu));
after_dgate = static_cast<float>(static_cast<IType>(after_dgate));
cache_buffer_act[cached_idx] = after_dsilu;
cache_buffer_gate[cached_idx] = after_dgate;
thread_amax = std::max(thread_amax, std::abs(after_dsilu));
thread_amax = std::max(thread_amax, std::abs(after_dgate));
} else {
float after_silu = silu(silu_elt) * gate_elt;
// Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32
after_silu = static_cast<float>(static_cast<IType>(after_silu));
cache_buffer_act[cached_idx] = after_silu;
thread_amax = std::max(thread_amax, std::abs(after_silu));
}
}
}
if (is_rowwise) {
for (size_t i = i_min; i < i_max; ++i) {
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
}
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate;
if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
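// Rowwise gate scales are stored right after the activation scales,
// offset by the number of 32-wide column blocks: (cols + 31) / 32.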
const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
output_scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate;
output_rowwise[i * stride + j] = static_cast<OType>(after_act);
output_rowwise[i * stride + cols + j] = static_cast<OType>(after_gate);
} else {
output_rowwise[i * cols + j] = static_cast<OType>(after_act);
}
}
}
}
if (is_colwise) {
for (size_t j = j_min; j < j_max; ++j) {
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
}
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate;
if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
const int scale_idx_gate = scale_idx_act + cols;
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
output_scales_colwise[scale_idx_gate] = biased_exponent_gate;
}
for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate;
output_colwise[i * stride + j] = static_cast<OType>(after_act);
output_colwise[i * stride + cols + j] = static_cast<OType>(after_gate);
} else {
output_colwise[i * cols + j] = static_cast<OType>(after_act);
}
}
}
}
}
@@ -156,26 +179,6 @@ void compute_ref_x1(const IType* grad,
ref_amax = amax;
}
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x2(const IType* grad,
const IType* input,
OType* output_rowwise,
OType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/**
* Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias):
@@ -183,12 +186,13 @@ void compute_ref_x2(const IType* grad,
* OR
* 2) Scaled columns + column-wise scaling factors
*/
template <bool IS_DGATED, typename IType, typename OType>
template <typename IType, typename OType>
void performTest_x1(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
InputsFillCase fill_case,
const bool IS_DGATED) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
@@ -198,12 +202,6 @@ void performTest_x1(const size_t rows,
const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
NVTE_CHECK(rowwise || colwise);
// std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl;
// std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl;
// std::cout << "blocks_Y: " << blocks_Y << std::endl;
// std::cout << "blocks_X: " << blocks_X << std::endl;
// std::cout << "scales_stride: " << scales_stride << std::endl;
Tensor grad("grad", std::vector<size_t>{ rows, cols }, itype);
Tensor input("input", std::vector<size_t>{ rows, cols * 2 }, itype);
@@ -229,12 +227,12 @@ void performTest_x1(const size_t rows,
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
if (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
if (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
@@ -245,30 +243,48 @@ void performTest_x1(const size_t rows,
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x1<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output.get(),
ref_output_scales.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
compute_ref<IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output.get(),
ref_output.get(),
ref_output_scales.get(),
ref_output_scales.get(),
ref_amax,
IS_DGATED,
rows,
cols,
scales_stride,
scales_stride,
rowwise,
colwise);
size_t mismatches_scales = 0;
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 1.0;
const double rel_tolerable_mismatches_limit = 1.0e-4;
const uint8_t * const gpu_scales_ptr = rowwise
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
}
const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts);
}
/**
@@ -278,12 +294,13 @@ void performTest_x1(const size_t rows,
* AND
* 2) Scaled columns + column-wise scaling factors
*/
template <bool IS_DGATED, typename IType, typename OType>
template <typename IType, typename OType>
void performTest_x2(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
InputsFillCase fill_case,
const bool IS_DGATED) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
@@ -325,12 +342,12 @@ void performTest_x2(const size_t rows,
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
if (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
if (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
@@ -341,30 +358,49 @@ void performTest_x2(const size_t rows,
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x2<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output_rowwise.get(),
ref_output_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
compute_ref<IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output_rowwise.get(),
ref_output_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_amax,
IS_DGATED,
rows,
cols,
scales_stride_rowwise,
scales_stride_colwise,
true,
true);
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 1.0;
const double rel_tolerable_mismatches_limit = 1.0e-4;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise);
}
std::vector<std::pair<size_t, size_t>> matrix_sizes = {
@@ -375,8 +411,8 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{256, 256},
{993, 512},
{768, 1024},
{65504, 128},
{16384, 1632},
{8192, 128},
{577, 1632},
};
std::vector<std::pair<size_t, size_t>> block_sizes = {
......@@ -393,9 +429,9 @@ std::vector<InputsFillCase> input_scenarios = {
// InputsFillCase::maxNorm_to_inf
};
std::vector<bool> is_dgated_op = {
true,
false
std::vector<bool> is_bwd_op = {
false,
true
};
} // namespace
@@ -427,21 +463,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType,
if (block_size.first == 1 || block_size.second == 1) {
if (IS_DGATED) {
performTest_x1<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x1<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
performTest_x1<IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case, IS_DGATED);
} else {
if (IS_DGATED) {
performTest_x2<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x2<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
performTest_x2<IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case, IS_DGATED);
}
);
);
@@ -456,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios),
::testing::ValuesIn(is_dgated_op)),
::testing::ValuesIn(is_bwd_op)),
[](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) {
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
@@ -465,6 +491,6 @@
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
test::caseName(std::get<4>(info.param)) + "X" +
(std::get<5>(info.param) ? "DGATED" : "GATED");
(std::get<5>(info.param) ? "BWD" : "FWD");
return name;
});
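
For reference, the per-element math these gated tests check: the forward pass computes silu(x) * gate, while the backward (dgated) pass needs both dsilu(x) * grad * gate and silu(x) * grad, exactly as in the cached-tile loop of compute_ref above. A standalone sketch of those formulas:

#include <cmath>
#include <cstdio>

// Per-element SwiGLU forward and backward, matching compute_ref above.
float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float silu(float x)    { return x * sigmoid(x); }
// d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
float dsilu(float x) {
    const float s = sigmoid(x);
    return s + x * s * (1.0f - s);
}

int main() {
    const float x = 0.7f, gate = -1.3f, grad = 0.25f;
    const float fwd        = silu(x) * gate;          // gated forward output
    const float d_wrt_x    = dsilu(x) * grad * gate;  // gradient w.r.t. activation input
    const float d_wrt_gate = silu(x) * grad;          // gradient w.r.t. gate input
    std::printf("fwd=%f dx=%f dgate=%f\n", fwd, d_wrt_x, d_wrt_gate);
}
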
@@ -523,10 +523,13 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
void compareResults_sequential(const std::string &name, const Tensor &test,
const void *ref, const bool rowwise,
double atol, double rtol, bool if_on_gpus) {
double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
if (if_on_gpus) test.to_cpu();
const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
const size_t N = product(shape);
size_t mismatches_num = 0;
int first_mismatch_idx = -1;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref);
@@ -547,80 +550,102 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(assertion) << "Error in tensor " << name << " in "
<< direction << " direction." << std::endl
<< "Mismatch at place " << to_string(unravel(i, shape))
<< " (" << std::to_string(i) << "): " << t << " vs " << r;
if (assertion) {
mismatches_num++;
if (first_mismatch_idx == -1) {
first_mismatch_idx = i;
}
}
if (mismatches_num > tolerable_mismatches_limit) {
const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);
GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl
<< "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
<< " (" << std::to_string(first_mismatch_idx) << "): "
<< first_mismatch_t << " vs " << first_mismatch_r;
}
}
);
}
template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
const size_t N, const double atol, const double rtol) {
const size_t N, const double atol, const double rtol,
size_t& mismatches) {
int first_mismatch_idx = N;
bool is_mismatch_found = false;
#pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \
reduction(min: first_mismatch_idx) proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
if (is_mismatch_found) { // early escape of the omp thread
continue;
}
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
#pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
{
size_t thread_mismatches = 0;
#pragma omp for schedule(static)
for (size_t i = 0; i < N; ++i) {
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = mismatch && (data_type == DType::kFloat32);
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
if (assertion && i < first_mismatch_idx) {
first_mismatch_idx = i;
is_mismatch_found = true;
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = mismatch && (data_type == DType::kFloat32);
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
if (assertion) {
if (i < first_mismatch_idx) {
first_mismatch_idx = i;
}
thread_mismatches++;
}
}
mismatches += thread_mismatches;
}
return first_mismatch_idx;
}
void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
const bool rowwise, double atol, double rtol, bool if_on_gpus) {
const bool rowwise, double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
if (if_on_gpus) test.to_cpu();
const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
const size_t N = product(shape);
size_t mismatches = 0;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref);
const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol);
if (i != N) {
const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
const double t = static_cast<double>(test_data[i]);
const double r = static_cast<double>(ref_data[i]);
std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(true) << "Error in tensor " << name << " in "
<< direction << " direction." << std::endl
<< "Mismatch at place " << to_string(unravel(i, shape))
<< " (" << std::to_string(i) << "): " << t << " vs " << r;
GTEST_FAIL() << mismatches << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl
<< "Mismatch at place " << to_string(unravel(i, shape))
<< " (" << std::to_string(i) << "): " << t << " vs " << r;
}
);
}
void compareResults(const std::string &name, const Tensor &test, const void *ref,
const bool rowwise, double atol, double rtol, bool if_on_gpus) {
const bool rowwise, double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
constexpr bool sequential = false;
if constexpr (sequential) {
compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus);
compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
} else {
compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus);
compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
}
}
@@ -657,25 +682,39 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
}
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride)
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
{
const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit));
mismatches_num = 0;
std::vector<int> mismatch_indices;
for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j;
ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl
<< "Mismatch: " << static_cast<int>(test[idx]) << " vs "
<< static_cast<int>(ref[idx]) << " at index " << idx;
}
}
}
const int test_val = static_cast<int>(test[idx]);
const int ref_val = static_cast<int>(ref[idx]);
const int abs_delta = std::abs(test_val - ref_val);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t N)
{
for (int i = 0; i < N; i++) {
ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
<< "Mismatch: " << static_cast<int>(test[i]) << " vs "
<< static_cast<int>(ref[i]) << " at index " << i;
if (abs_delta > atol) {
mismatches_num++;
mismatch_indices.push_back(idx);
}
if (mismatches_num > tolerable_mismatches_limit) {
std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):"
<< static_cast<int>(test[index]) << " vs "
<< static_cast<int>(ref[index]) << std::endl;
}
GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << ".";
}
}
}
}
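
The comparison helpers above now tolerate a bounded number of mismatches instead of failing on the first one, to absorb FP8 round-to-nearest ties. A condensed sketch of that policy follows; the names are illustrative, not the test suite's API.

#include <cmath>
#include <cstdio>
#include <vector>

// Tolerance check and bounded mismatch count, mirroring the compareResults
// change above. Illustrative sketch only.
bool within_tolerance(double t, double r, double atol, double rtol) {
    return std::fabs(t - r) <= atol || (r != 0 && std::fabs((t - r) / r) <= rtol);
}

bool compare_with_limit(const std::vector<double>& test, const std::vector<double>& ref,
                        double atol, double rtol, size_t tolerable_mismatches_limit) {
    size_t mismatches = 0;
    for (size_t i = 0; i < test.size(); ++i) {
        if (!within_tolerance(test[i], ref[i], atol, rtol)) ++mismatches;
    }
    if (mismatches > tolerable_mismatches_limit) {
        std::printf("%zu mismatch(es), above the limit of %zu\n",
                    mismatches, tolerable_mismatches_limit);
        return false;
    }
    return true;
}

int main() {
    const std::vector<double> ref{1.0, 2.0, 3.0}, test{1.0, 2.00001, 3.5};
    return compare_with_limit(test, ref, 1e-5, 1e-8, 1) ? 0 : 1;
}
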
......
@@ -413,7 +413,12 @@ inline fp8e8m0 float_to_e8m0(float val) {
}
inline float exp2f_rcp(fp8e8m0 biased_exp) {
return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
if (biased_exp == 0) {
return 1.0f;
}
int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127)
float fp32_val = *reinterpret_cast<float*>(&int_val);
return fp32_val;
}
inline float identity(const float x) { return x; }
@@ -445,15 +450,18 @@ size_t last_dimension(const std::vector<size_t> &shape);
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
void compareResults(const std::string &name, const Tensor &test, const void *ref,
bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true);
bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
const size_t tolerable_mismatches_limit = 0);
void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t N);
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num,
const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols);
......
......@@ -78,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape):
return False
def assert_bitwise_scaled_tensors(
a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
if not precise_comparison:
assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
return
assert a.scaling_mode == b.scaling_mode
assert a.scale_inv.dtype == b.scale_inv.dtype
if a.scaling_mode.is_tensor_scaling():
......@@ -94,8 +100,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
assert_allclose(a.data, b.data)
elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
assert_bitwise_scaled_tensors(
a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
)
assert_bitwise_scaled_tensors(
a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
)
else:
pytest.fail("Unsupported input types")
......@@ -481,24 +491,7 @@ class TestNorm:
# if the input dtype is not float32
precise_comparison = False
assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)
assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
if norm_type == "layernorm":
......@@ -768,12 +761,24 @@ class TestFusedQuantize:
)(dz, x)
if is_casted_output:
    # TE kernels cast the intermediate results to the input dtype, which reduces
    # precision compared to the JAX implementation
    precise_comparison = not (
        in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
    )
    assert_bitwise_scaled_tensors(
        te_output, jax_output, precise_comparison=precise_comparison
    )
else:
assert_allclose(te_output, jax_output)
if is_dbias:
    # TE kernels cast the intermediate results to the input dtype, which reduces
    # precision compared to the JAX implementation; for dbias this typically only
    # affects bfloat16
    precise_comparison = not (
        in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
    )
    assert_allclose(
        te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
    )
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
......
......@@ -192,6 +192,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif()
......
......@@ -162,10 +162,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
(offset_elems * type_num_bits) / 8);
  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
             "Tensor data pointer must be 16B aligned");

  const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
"-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
......
......@@ -668,7 +668,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
// Alignment requirements for the Tensor Memory Accelerator (TMA)
constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment
constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment
inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
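Both alignment idioms used in this change, in one host-side sketch (illustrative only, power-of-two alignments assumed): the modulo predicate behind is_aligned_ptr, and the add-and-mask round-up that the kernels below apply to dynamic shared memory:

#include <cassert>
#include <cstdint>

int main() {
  constexpr uintptr_t ALIGN = 128;  // mirrors TMA_SHMEM_ALIGNMENT
  alignas(256) static char buf[512];
  // Simulate a misaligned base pointer, then round it up with add-and-mask
  const uintptr_t base = reinterpret_cast<uintptr_t>(buf) + 5;
  const uintptr_t aligned = (base + ALIGN - 1) & ~(ALIGN - 1);
  assert(aligned % ALIGN == 0);    // same predicate as is_aligned_ptr
  assert(aligned - base < ALIGN);  // padding never exceeds ALIGN - 1 bytes
  return 0;
}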
......
......@@ -27,14 +27,8 @@
namespace transformer_engine {
template <typename T1, typename T2>
__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) {
return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
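For example (assuming DIVUP(N, M) is the usual ceiling division), DIVUP_TO_MULTIPLE rounds a byte count up to the next multiple of the alignment, so consecutive buffers can start at aligned offsets:

#include <cassert>
#include <cstdint>

static uint64_t DIVUP(uint64_t n, uint64_t m) { return (n + m - 1) / m; }
static uint64_t DIVUP_TO_MULTIPLE(uint64_t n, uint64_t m) { return DIVUP(n, m) * m; }

int main() {
  // A 300-byte buffer padded to the 128B TMA shared-memory alignment
  assert(DIVUP_TO_MULTIPLE(300, 128) == 384);
  assert(DIVUP_TO_MULTIPLE(384, 128) == 384);  // already aligned: unchanged
  return 0;
}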
namespace gated_kernels {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
......@@ -76,18 +70,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
  constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0;
......@@ -96,8 +91,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
......@@ -269,9 +262,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
namespace mxfp8_kernel {
constexpr size_t CHUNK_DIM_Y = 64;
constexpr size_t CHUNK_DIM_X = 64;
constexpr size_t THREADS_PER_CHUNK_COLWISE = 128;
constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = CHUNK_DIM_X;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
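A host-side sketch (illustrative, not part of the kernel) of the swizzling arithmetic used with these constants: each thread rotates its 4-element pack by its bank_group, so the four threads that would otherwise hit the same bank access different bank groups in every wave:

#include <cstdio>

int main() {
  const int SCALE_DIM_X = 32, PACK_SIZE = 4, WAVES = SCALE_DIM_X / PACK_SIZE;
  const int THREADS_PER_BANK = 4;
  for (int lane = 0; lane < 8; ++lane) {            // a few lanes of a warp
    const int bank_group = lane / THREADS_PER_BANK; // as computed in the kernel
    printf("lane %d packs:", lane);
    for (int w = 0; w < WAVES; ++w) {
      // Same swizzle as the kernel's rowwise loads/stores
      printf(" %2d", ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X);
    }
    printf("\n");
  }
  return 0;
}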
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
......@@ -284,43 +302,73 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using IType2 = typename ptx::FPx2<IType>;
  using OType2 = typename ptx::FPx2<OType>;

  constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
  static_assert(STAGES >= 1);

  constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING;
  constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING);

  // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension.
  constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);

  const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
  const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
  const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;

  constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X;

  const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
  const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X;

  const int thread_offset_Y_rowwise = tid_Y_rowwise;
  const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
  const int thread_offset_Y_colwise = tid_Y_colwise;
  const int thread_offset_X_colwise = tid_X_colwise;

  const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
  const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
  const int col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols);
  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

  const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
  const int gate_scale_idx_offset_colwise = cols;

  // Helps resolve bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
  __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X];

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
......@@ -329,12 +377,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t in_mem = in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
  const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
  const size_t out_mem = out_act_mem + out_gate_mem;
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
......@@ -346,374 +391,493 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
OType *out_act_colwise_sh = out_act_rowwise_sh;
OType *out_gate_colwise_sh = out_gate_rowwise_sh;
if constexpr (ROWWISE_SCALING && COLWISE_SCALING) {
out_act_colwise_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem);
out_gate_colwise_sh =
reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem + out_act_mem);
}
  IType *cached_act_sh = in_act_sh;    // in_act_sh is used as a cache buffer for activations
  IType *cached_gate_sh = in_gate_sh;  // in_gate_sh is used as a cache buffer for gated values

  constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
  // Prefetch data of the first stage
  if constexpr (IS_DGATED) {
    copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y,
                        &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y,
                        &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y,
                        shmem_buff_size, &mbar[0], is_master_thread);
  } else {
    copy_2d_to_sharedx2(&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y,
                        &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y,
                        shmem_buff_size, &mbar[0], is_master_thread);
  }
#pragma unroll
  for (int stage = 0; stage < STAGES; ++stage) {
    const int buff = stage % BUFFS_NUM;
    const int next_stage = stage + 1;
    const int stage_offset_Y = stage * BUFF_DIM_Y;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const int next_buff = next_stage % BUFFS_NUM;
      const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int next_buff_offset = next_buff * BUFF_DIM;
      if constexpr (IS_DGATED) {
        copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
                            global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
                            global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset],
                            &tensor_map_input_gate, global_offset_X, global_offset_Y,
                            shmem_buff_size, &mbar[next_stage], is_master_thread);
      } else {
        copy_2d_to_sharedx2(&in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X,
                            global_offset_Y, &in_gate_sh[next_buff_offset], &tensor_map_input_gate,
                            global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
                            is_master_thread);
      }
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], parity);
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise =
buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise;
float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f;
float after_act_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE];
float after_gate_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE];
// 1. Read/Compute elements. Find MXFP8-block AMAX
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
        const int shmem_offset_colwise =
            shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;

        float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]);
        float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
        float after_act_elt;
        float after_gate_elt;

        if constexpr (IS_DGATED) {
          float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
          const float x = act_elt;
          float act_x;
          float dact_x;
          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
            const float s = sigmoidf(x);
            act_x = x * s;
            dact_x = x * s * (1 - s) + s;
          } else {
            act_x = ActOP(x, {});
            dact_x = DActOP(x, {});
          }
          after_act_elt = dact_x * grad_elt * gate_elt;
          after_gate_elt = act_x * grad_elt;
        } else {
          after_act_elt = ActOP(act_elt, {}) * gate_elt;
        }
        // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
        if constexpr (!std::is_same_v<IType, float>) {
          after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
          if constexpr (IS_DGATED) {
            after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
          }
        }

        after_act_colwise[i] = after_act_elt;
        if constexpr (IS_DGATED) {
          after_gate_colwise[i] = after_gate_elt;
        }

        // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
        if constexpr (IS_CACHED_ACT_OP) {
          cached_act_sh[shmem_offset_colwise] = static_cast<IType>(after_act_elt);
          if constexpr (IS_DGATED) {
            cached_gate_sh[shmem_offset_colwise] = static_cast<IType>(after_gate_elt);
          }
        }

        const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
        const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
        if (!out_of_bounds) {
          thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
          if constexpr (IS_DGATED) {
            thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
          }
        }
      }
      if constexpr (ONLY_COLWISE_SCALING) {
        // Threads whose id along the Y-dim is 0 don't need to store to shared memory,
        // as they manage the columnwise reduction of the amax
        if (tid_Y_colwise > 0) {
          subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act;
        }
        __syncthreads();
        if (tid_Y_colwise == 0) {
#pragma unroll
          for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) {
            const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise];
            __builtin_assume(thread_amax_act >= 0);
            __builtin_assume(other_thread_amax >= 0);
            thread_amax_act = fmaxf(thread_amax_act, other_thread_amax);
          }
          subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act;
        }
        __syncthreads();
        // All threads read the reduced amax (ACT)
        thread_amax_act = subamax_colwise_buff[0][tid_X_colwise];

        if constexpr (IS_DGATED) {
          // Make sure the previous read of the ACT values has been completed,
          // so the data are not overwritten
          __syncthreads();
          if (tid_Y_colwise > 0) {
            subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate;
          }
          __syncthreads();
          if (tid_Y_colwise == 0) {
#pragma unroll
            for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) {
              const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise];
              __builtin_assume(thread_amax_gate >= 0);
              __builtin_assume(other_thread_amax >= 0);
              thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax);
            }
            subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate;
          }
          __syncthreads();
          // All threads read the reduced amax (GATE)
          thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise];
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent_act =
          ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);

      const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
      const int global_scales_offset_X = scales_offset_X_colwise;
      const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;

      const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
      const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
      if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
        scales_colwise[scale_idx] = biased_exponent_act;
      }

      float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act);
      float block_scale_inverse_gate;

      if constexpr (IS_DGATED) {
        const e8m0_t biased_exponent_gate =
            ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
        const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
        if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
          scales_colwise[scale_idx_gate] = biased_exponent_gate;
        }
        block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate);
      }

      // 3. Scale elements
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
        const int shmem_offset_elt =
            shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
        if constexpr (IS_DGATED) {
          OType2 out_pair;
          ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]};
          const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act,
                                                            block_scale_inverse_gate};
          ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair);
          out_act_colwise_sh[shmem_offset_elt] = out_pair.x;
          out_gate_colwise_sh[shmem_offset_elt] = out_pair.y;
        } else {
          const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i];
          out_act_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out_act);
        }
      }
    }
if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f;
Vec<IType, PACK_SIZE> in_cached_act[WAVES];
Vec<IType, PACK_SIZE> in_cached_gate[WAVES];
float after_act_rowwise[SCALE_DIM_X];
float after_gate_rowwise[SCALE_DIM_X];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x_act = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]);
}
          // Since the TMA alignment requirement for the data is 16B (i.e. cols % 8 == 0 for BF16 elements),
          // a single check along the column direction is sufficient to ensure the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e]));
if constexpr (IS_DGATED) {
thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e]));
}
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e],
in_cached_act[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act);
if constexpr (IS_DGATED) {
const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e],
in_cached_gate[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate);
}
}
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax_act = static_cast<float>(
__hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y)));
if constexpr (IS_DGATED) {
thread_amax_gate = static_cast<float>(
__hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y)));
}
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in_grad;
Vec<IType, PACK_SIZE> in_act;
Vec<IType, PACK_SIZE> in_gate;
in_act.load_from(&in_act_sh[shmem_offset_rowwise]);
in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]);
}
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
float act_elt = static_cast<float>(in_act.data.elt[e]);
float gate_elt = static_cast<float>(in_gate.data.elt[e]);
float after_act_elt;
float after_gate_elt;
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = act_x * grad_elt;
after_act_rowwise[j] = after_act_elt;
after_gate_rowwise[j] = after_gate_elt;
} else {
after_act_elt = ActOP(act_elt, {}) * gate_elt;
after_act_rowwise[j] = after_act_elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
if constexpr (IS_DGATED) {
after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
}
}
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
if constexpr (IS_DGATED) {
thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
}
}
}
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx] = biased_exponent_act;
}
const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act);
const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act,
block_scale_inverse_act};
float block_scale_inverse_gate;
ptx::floatx2 block_scale_inverse_2x_gate;
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate);
block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate};
}
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<OType2, PACK_SIZE / 2> out_act;
Vec<OType2, PACK_SIZE / 2> out_gate;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
IType2 in_act;
OType2 &out_act_pair = reinterpret_cast<OType2 &>(out_act.data.elt[e]);
if constexpr (IS_CACHED_ACT_OP) {
in_act.x = in_cached_act[w].data.elt[2 * e];
in_act.y = in_cached_act[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in_act.x = after_act_rowwise[j];
in_act.y = after_act_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act);
if constexpr (IS_DGATED) {
IType2 in_gate;
OType2 &out_gate_pair = reinterpret_cast<OType2 &>(out_gate.data.elt[e]);
if constexpr (IS_CACHED_ACT_OP) {
in_gate.x = in_cached_gate[w].data.elt[2 * e];
in_gate.y = in_cached_gate[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in_gate.x = after_gate_rowwise[j];
in_gate.y = after_gate_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
}
}
}
    // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
      const int global_offset_Y = block_offset_Y + stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int buff_offset = buff * BUFF_DIM;

      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_act_rowwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_rowwise_sh[buff_offset]));
        if constexpr (IS_DGATED) {
          ptx::cp_async_bulk_tensor_2d_shared_to_global(
              reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_rowwise), global_offset_X,
              global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_rowwise_sh[buff_offset]));
        }
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_act_colwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_colwise_sh[buff_offset]));
        if constexpr (IS_DGATED) {
          ptx::cp_async_bulk_tensor_2d_shared_to_global(
              reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_colwise), global_offset_X,
              global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_colwise_sh[buff_offset]));
        }
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }

  ptx::cp_async_bulk_wait_group_read<0>();
  __syncthreads();

  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace mxfp8_kernel
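Outside the kernel, the per-block MXFP8 recipe it implements can be stated in a few scalar lines. The sketch below uses simplified stand-ins for the helpers (this float_to_e8m0 rounds conservatively up to a power of two, and the E4M3 max-norm 448 is an assumption for illustration, not TE's exact code):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Simplified stand-ins for the kernel's helpers (assumptions, not TE's exact code)
static uint8_t float_to_e8m0(float v) {
  int exp;
  std::frexp(v, &exp);  // v = m * 2^exp, m in [0.5, 1): rounds v up to 2^exp
  return static_cast<uint8_t>(exp + 127);
}
static float exp2f_rcp(uint8_t biased_exp) {
  return std::exp2f(127.0f - static_cast<float>(biased_exp));
}

int main() {
  const float block[4] = {0.5f, -3.0f, 12.0f, -448.0f};  // one tiny "MX block"
  float amax = 0.f;
  for (float x : block) amax = std::fmax(amax, std::fabs(x));
  const float max_norm_rcp = 1.0f / 448.0f;  // E4M3 max-norm reciprocal (assumed)
  const uint8_t scale = float_to_e8m0(amax * max_norm_rcp);
  const float scale_rcp = exp2f_rcp(scale);
  for (float x : block) {
    std::printf("%g -> %g\n", x, x * scale_rcp);  // scaled values now fit E4M3
  }
  return 0;
}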
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
......@@ -771,17 +935,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
  const size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  const size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
  const size_t in_act_mem = buff_size_aligned_in;
  const size_t in_gate_mem = buff_size_aligned_in;

  const size_t out_act_mem = buff_size_aligned_out;
  const size_t out_gate_mem = buff_size_aligned_out;

  const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
                            (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
cudaFuncSetAttribute(
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>,
......@@ -809,16 +972,34 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
// TODO: Make more general
const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1;
const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1;
ScalingType scaling_type;
if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) {
scaling_type = ScalingType::ROWWISE;
} else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) {
scaling_type = ScalingType::COLWISE;
} else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) {
scaling_type = ScalingType::BIDIMENSIONAL;
}
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y;
constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X;
constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM;
const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X);
constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE;
constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE;
const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE)
? THREADS_PER_CHUNK_COLWISE
: THREADS_PER_CHUNK_NON_COLWISE;
const dim3 grid(blocks_X, blocks_Y);
const dim3 block_size(THREADS_PER_CHUNK);
size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1;
size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1;
......@@ -828,94 +1009,122 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
e8m0_t *const scales_colwise_ptr =
USE_COLWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
alignas(64) CUtensorMap tensor_map_output_act_colwise{};
alignas(64) CUtensorMap tensor_map_output_gate_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, cols, input_type_bit_size);
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
size_t out_mem = out_act_mem + out_gate_mem;
if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, true, false,
THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, false, THREADS_PER_CHUNK_NON_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::COLWISE:
cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, false, true,
THREADS_PER_CHUNK_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
false, true, THREADS_PER_CHUNK_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, true, true,
THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, true, THREADS_PER_CHUNK_NON_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}); // NOLINT(*)
); // NOLINT(*)
}
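The launcher above maps a runtime ScalingType to the compile-time ROWWISE_SCALING/COLWISE_SCALING template flags and a matching thread count. A stripped-down sketch of the same dispatch pattern (the run function is a hypothetical stand-in for the kernel launch):

#include <cstdio>

enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };

template <bool ROWWISE, bool COLWISE>
void run() {  // stand-in for instantiating and launching the kernel
  std::printf("rowwise=%d colwise=%d\n", ROWWISE, COLWISE);
}

void dispatch(ScalingType t) {
  switch (t) {
    case ScalingType::ROWWISE:       run<true, false>(); break;
    case ScalingType::COLWISE:       run<false, true>(); break;
    case ScalingType::BIDIMENSIONAL: run<true, true>();  break;
  }
}

int main() { dispatch(ScalingType::BIDIMENSIONAL); }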
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
......
......@@ -28,36 +28,25 @@
namespace transformer_engine {
namespace mxfp8_kernel {
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
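// Compile-time sketch (not used by the kernels) of the swizzled index math that appears
// in the rowwise loops below: lanes are grouped in fours (THREADS_PER_BANK), each group
// covering the full 128-byte bank width, and every group starts its waves at a different
// PACK_SIZE-wide offset, rotating modulo SCALE_DIM_X, so the groups hit different banks.
constexpr int swizzled_wave_start(int wave, int thread_lane) {
  const int bank_group = thread_lane / static_cast<int>(THREADS_PER_BANK);  // 0..7 in a warp
  return static_cast<int>(((wave + bank_group) * PACK_SIZE) % SCALE_DIM_X);
}
static_assert(swizzled_wave_start(0, 0) == 0, "lane 0 starts at column 0");
static_assert(swizzled_wave_start(0, 4) == 4, "next bank group starts one pack later");
static_assert(swizzled_wave_start(7, 4) == 0, "wave index wraps modulo SCALE_DIM_X");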
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, size_t SCALE_DIM_Y,
size_t SCALE_DIM_X>
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
@@ -67,201 +56,341 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
if (noop != nullptr && noop[0] == 1.0f) return;
constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT;
constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS;
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
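// Worked numbers for the 64x64 chunk configuration (THREADS_PER_CHUNK = 64,
// SCALE_DIM_X = 32): THREADS_X = 64 / 32 = 2, THREADS_Y = 64 / 2 = 32, so each of the
// two buffers holds 32x64 elements and a chunk is processed in STAGES = 64 / 32 = 2 stages.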
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X;
const int tid_X_rowwise = threadIdx.x % THREADS_X;
const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise;
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
// Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0);
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) does not guarantee the pointer is aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
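// Round-up sketch, assuming TMA_SHMEM_ALIGNMENT = 128 (0x80): a base of 0x1010 maps to
// (0x1010 + 0x7F) & ~0x7F = 0x1080, the next 128-byte boundary, while an already aligned
// base such as 0x1080 maps to itself.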
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 64
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 2 = 64 / 32
constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y =
SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 64 = 64 * 1
constexpr size_t SCALES_ROWWISE_PER_BLOCK_X =
SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 2 = 2 * 1
constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64
constexpr size_t SCALES_COLWISE_PER_BLOCK_Y =
SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1
constexpr size_t SCALES_COLWISE_PER_BLOCK_X =
SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X;
const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y;
const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X;
const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y;
const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X;
const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE;
const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE;
// const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE;
const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE;
const int thread_offset_Y = tid_rowwise_Y;
const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD;
// const int thread_offset_X_colwise = tid_colwise_X;
const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y;
const int dbias_rowwise_block_offset_X =
blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise;
const int dbias_colwise_offset_Y = blockIdx.y;
const int dbias_colwise_block_offset_X =
blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X;
const int dbias_stride = cols;
const bool is_master_thread = (threadIdx.x == 0);
Vec<float, ELEMS_PER_THREAD> partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X];
float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X];
float partial_dbias_colwise = 0.0f;
float thread_dbias_rowwise[SCALE_DIM_X];
if constexpr (IS_DBIAS) {
if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
#pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
partial_dbias_rowwise[i].clear();
}
} else {
#pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
partial_dbias_colwise[i] = 0;
}
for (int j = 0; j < SCALE_DIM_X; ++j) {
thread_dbias_rowwise[j] = 0.0f;
}
}
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
__shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
__shared__ alignas(128)
OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
__shared__ alignas(128)
OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
float block_amax = 0;
float block_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS];
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<MXFP8_ITERATIONS, MXFP8_THREADS_PER_CHUNK>(mbar, is_master_thread);
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
#pragma unroll
for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) {
const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X;
const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X;
const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y;
const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0],
&tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
}
const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const int scales_rowwise_chunk_offset_Y =
scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y;
const int scales_rowwise_chunk_offset_X =
scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X;
const int scales_colwise_chunk_offset_Y =
scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y;
const int scales_colwise_chunk_offset_X =
scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
#pragma unroll
for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y;
const int chunk_stage_offset_X = chunk_offset_X;
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size,
&mbar[prefetch_buff], is_master_thread);
copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
} else {
copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff],
is_master_thread);
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity);
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
thread_amax = 0.0f;
float in_compute_colwise[BUFF_DIM_Y];
IType in_colwise_IType[BUFF_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) {
const int buff = iter % MXFP8_BUFFERS_NUM;
const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
if (next_iter < MXFP8_ITERATIONS) {
const int next_buff = next_iter % MXFP8_BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size,
&mbar[next_iter], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread);
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
}
}
thread_amax = static_cast<float>(thread_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
ptx::fence_proxy_async_shared_cta();
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[shmem_offset_colwise]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
partial_dbias_colwise += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
} else {
// Without activations, out-of-bounds elements stay zero (TMA zero-fills), so they can safely be included in the amax
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in;
Vec<IType, ELEMS_PER_THREAD> act_in;
Vec<OType, ELEMS_PER_THREAD> out_c;
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
const int iteration_scale_rowwise_offset_Y =
scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X_rowwise;
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = (row >= rows);
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]);
if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
}
if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
thread_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
float thread_amax = 0;
float in_compute[ELEMS_PER_THREAD];
// used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since TMA requires 16-byte data alignment (i.e. cols % 8 == 0 for BF16 elements),
// a single check in the column direction suffices to ensure the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols);
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
float elt = static_cast<float>(in.data.elt[j]);
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[shmem_offset_rowwise]);
}
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in.data.elt[j]);
float act_in_elt = static_cast<float>(act_in.data.elt[e]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
if (!out_of_bounds) {
partial_dbias_rowwise[chunk_X].data.elt[j] += elt;
}
}
in_compute[j] = elt;
if constexpr (IS_ACT || IS_DACT) {
// If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again
if constexpr (IS_DBIAS && (!COLWISE_SCALING)) {
thread_dbias_rowwise[j] += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
@@ -269,196 +398,141 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
// Without activations, out-of-bounds elements stay zero (TMA zero-fills), so they can safely be included in the amax
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
__builtin_assume(block_amax >= 0);
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
const float subwarp_amax = subwarp_reduce_max_broadcast<SUBWARP_WIDTH>(thread_amax);
const e8m0_t biased_exponent =
float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp);
// Only single thread writes the computed scaling factor
if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) {
const int global_scales_offset_Y =
iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y;
const int global_scales_offset_X =
scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE;
const int scale_idx =
global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = exp2f_rcp(biased_exponent);
#pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
out_c.data.elt[j] = static_cast<OType>(in_compute[j] * block_scale_inverse);
}
out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]);
}
}
if constexpr (USE_COLWISE_SCALING) {
const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols);
float in_compute[SCALE_DIM_Y];
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent;
float amax = 0;
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const size_t row = row_base + i;
const bool row_out_of_bounds = (row >= rows);
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
float elt = static_cast<float>(in_sh[buff][i][tid_colwise_X]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[buff][i][tid_colwise_X]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
if (!out_of_bounds) {
partial_dbias_colwise[chunk_X] += elt;
}
}
in_compute[i] = elt;
if constexpr (IS_ACT || IS_DACT) {
if (!out_of_bounds) {
amax = fmaxf(amax, fabsf(elt));
}
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<OType2, PACK_SIZE / 2> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
IType2 in;
OType2 &out_pair = reinterpret_cast<OType2 &>(out.data.elt[e]);
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = in_IType[w].data.elt[e];
} else if constexpr (IS_CACHED_ACT_OP) {
in.x = in_cached[w].data.elt[2 * e];
in.y = in_cached[w].data.elt[2 * e + 1];
} else {
// If no activation, elt is 0 so we can safely do this
amax = fmaxf(amax, fabsf(elt));
const int j = w * PACK_SIZE + 2 * e;
in.x = in_compute_rowwise[j];
in.y = in_compute_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
}
}
__builtin_assume(block_amax >= 0);
__builtin_assume(amax >= 0);
block_amax = fmaxf(block_amax, amax);
const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp);
__builtin_assume(block_amax >= 0);
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
const float block_scale_inverse = exp2f_rcp(biased_exponent);
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
out_colwise_sh[buff][i][tid_colwise_X] =
static_cast<OType>(in_compute[i] * block_scale_inverse);
}
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff_offset]));
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
if constexpr (USE_ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff]));
}
if constexpr (USE_COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<MXFP8_PREFETCH_BUFFERS_NUM>();
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff_offset]));
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
parity ^= 1;
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
if constexpr (IS_DBIAS) {
if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X;
constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1;
constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE;
__shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD];
if (tid_rowwise_Y > 0) {
#pragma unroll
for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) {
partial_dbias_rowwise[c].store_to(
&shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]);
}
}
__syncthreads();
parity ^= 1;
if (tid_rowwise_Y == 0) {
#pragma unroll
for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) {
Vec<float, ELEMS_PER_THREAD> other_row_dbias;
const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X;
const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X;
if constexpr (IS_DBIAS) {
float thread_partial_dbias = 0.0f;
if constexpr (COLWISE_SCALING) {
thread_partial_dbias = partial_dbias_colwise;
} else {
// Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
// HEIGHT = THREADS_Y
// WIDTH = THREADS_X * (SCALE_DIM_X + 1)
// Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
const int left_bound = dbias_rowwise_offset_X;
const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1;
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int i = 0; i < Y; ++i) {
other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]);
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j];
}
}
// Vectorized store when all elements are inside the boundaries
if (right_bound < cols) {
partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]);
} else if (left_bound < cols && right_bound >= cols) {
// Element-by-element store when some elements cross the boundaries
const int in_bound_elts_count = cols - left_bound;
partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0,
in_bound_elts_count);
}
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
} else {
__syncthreads();
#pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X;
const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X;
const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols);
if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias_colwise[i];
}
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
const int dbias_stride = cols;
const int dbias_offset_Y = blockIdx.y;
const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
block_amax = reduce_max<MXFP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, block_amax);
}
destroy_barriers<MXFP8_ITERATIONS>(mbar, is_master_thread);
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace mxfp8_kernel
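// The kernel above overlaps TMA loads with compute through BUFFS_NUM = 2 buffers:
// while stage s is consumed from buffer s % 2, the load for stage s + 1 is issued into
// the other buffer. A minimal host-side analogue of that schedule; load(), wait(),
// process() and store() are illustrative stand-ins for copy_2d_to_shared,
// mbarrier_wait_parity, the cast/scale work, and the bulk store:
namespace pipeline_sketch {
constexpr int STAGES = 2;
constexpr int BUFFS_NUM = 2;
inline void load(float *, int) {}    // stand-in for the TMA shared-memory copy
inline void wait(int) {}             // stand-in for the mbarrier parity wait
inline void process(float *) {}      // stand-in for amax / scale / cast
inline void store(float *) {}        // stand-in for the bulk store to global
inline void run(float *buffs[BUFFS_NUM]) {
  load(buffs[0], 0);  // prologue: fetch the first stage
  for (int stage = 0; stage < STAGES; ++stage) {
    if (stage + 1 < STAGES) {
      load(buffs[(stage + 1) % BUFFS_NUM], stage + 1);  // prefetch the next stage
    }
    wait(stage);
    process(buffs[stage % BUFFS_NUM]);
    store(buffs[stage % BUFFS_NUM]);
  }
}
}  // namespace pipeline_sketch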
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
@@ -507,9 +581,12 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
@@ -678,8 +755,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
@@ -921,6 +998,7 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
using namespace mxfp8_kernel;
bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream);
@@ -936,16 +1014,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
}
CheckNoopTensor(*noop, "cast_noop");
// TODO: Make more general
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1;
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X);
const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y);
const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X);
constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT);
constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
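// Sizing example for the 64x64 chunk path: a 1024x2048 input gives
// blocks_Y = DIVUP(1024, 64) = 16 and blocks_X = DIVUP(2048, 64) = 32,
// i.e. a (32, 16) grid of 64-thread blocks.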
const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1;
const size_t scale_stride_colwise =
@@ -958,6 +1044,15 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols;
ScalingType scaling_type;
if (use_rowwise_scaling && (!use_colwise_scaling)) {
scaling_type = ScalingType::ROWWISE;
} else if ((!use_rowwise_scaling) && use_colwise_scaling) {
scaling_type = ScalingType::COLWISE;
} else if (use_rowwise_scaling && use_colwise_scaling) {
scaling_type = ScalingType::BIDIMENSIONAL;
}
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
@@ -972,58 +1067,107 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const dim3 block(MXFP8_THREADS_PER_CHUNK);
const dim3 grid(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, input_type_bit_size);
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, output_type_bit_size);
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size);
}
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_Y_colwise, SCALE_DIM_Y,
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_X_rowwise, SCALE_DIM_X,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
typeToNumBits(input.dtype()));
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
typeToNumBits(output->dtype()));
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows,
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
typeToNumBits(output->dtype()));
}
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
SCALE_DIM_Y, SCALE_DIM_X><<<grid, block, 0, stream>>>(
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
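// Worked example, assuming TMA_SHMEM_ALIGNMENT = 128, BF16 input, FP8 output, no dAct,
// bidimensional scaling (BUFF 32x64, BUFFS_NUM = 2): buff_elems_total = 2 * 32 * 64 = 4096,
// so in_mem = 4096 * 2 B = 8192 B, out_mem = 4096 B + 4096 B = 8192 B, and
// dshmem_size = 8192 + 8192 + 128 = 16512 B.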
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::COLWISE:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}
if constexpr (IS_DBIAS) {
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
}
namespace detail {
@@ -1117,8 +1261,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
case NVTE_DELAYED_TENSOR_SCALING: {
if (!IS_DBIAS && !IS_DACT) {
if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_gmem_alignment) &&
is_aligned_tensor_data(*output, TMA_gmem_alignment)) {
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8
cast_fp8_1D<IS_ACT, ParamOP, OP>(input, output, stream);
} else {
@@ -1127,9 +1271,9 @@
}
} else if (!IS_DBIAS && IS_DACT) {
if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_gmem_alignment) &&
is_aligned_tensor_data(*output, TMA_gmem_alignment) &&
is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) {
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8 (+dAct)
cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
@@ -84,8 +84,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// const int thread_offset_X_colwise = tid_colwise_X;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
__shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size;
@@ -166,7 +166,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X;
const e8m0_t biased_exponent = scales_ptr[scale_idx];
const float block_scale = exp2f(static_cast<float>(biased_exponent) - FP32_EXPONENT_BIAS);
const float block_scale = ptx::exp2f(biased_exponent);
if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in;
@@ -104,6 +104,53 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1
: __int_as_float((254 - biased_exp)
<< FP32_MANTISSA_BITS); // 127 - (biased_exp - 127)
}
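// Bit-level check of exp2f_rcp: a biased e8m0 exponent e encodes the scale 2^(e - 127),
// so its reciprocal is 2^(127 - (e - 127)) = 2^(254 - e); writing (254 - e) into the FP32
// exponent field builds exactly that power of two. For example:
//   e = 127 (scale 2^0 = 1)  -> exponent field 127 -> 1.0f
//   e = 130 (scale 2^3 = 8)  -> exponent field 124 -> 0.125f
//   e = 120 (scale 2^-7)     -> exponent field 134 -> 128.0f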
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
}
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out;
asm volatile(
"{\n"
"cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
"}"
: "=h"(out)
: "f"(val));
return *reinterpret_cast<e8m0_t *>(&out);
#else
// TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax.
if (isnan(val)) {
return 0xFF;
}
if (isinf(val)) {
return 0xFE;
}
if (val == 0.0f) {
return 0x00;
}
uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent;
}
return exponent;
#endif
}
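// Worked example of the fallback path: val = 448.0f = 1.75 * 2^8 has biased FP32
// exponent 135 and a non-zero mantissa, so the exponent is rounded up (matching the
// round-towards-plus-infinity semantics of cvt.rp) to 136, i.e. scale 2^9 = 512 >= 448.
// An exact power of two such as 256.0f (exponent 135, mantissa 0) stays at 135.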
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
@@ -169,6 +216,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
asm volatile("fence.proxy.async.shared::cta;");
}
template <typename T>
struct alignas(2 * sizeof(T)) FPx2 {
T x;
T y;
};
using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;
static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);
// SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
"mul.f32x2 val_pair, %1, %2; \n\t"
"mov.b64 {val2,val1}, val_pair; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
"mul.f32x2 val_pair, %1, %2; \n\t"
"mov.b64 {val2,val1}, val_pair; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_bf16; \n\t"
".reg.b16 val2_bf16; \n\t"
"mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
"cvt.f32.bf16 val1, val1_bf16; \n\t"
"cvt.f32.bf16 val2, val2_bf16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_bf16; \n\t"
".reg.b16 val2_bf16; \n\t"
"mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
"cvt.f32.bf16 val1, val1_bf16; \n\t"
"cvt.f32.bf16 val2, val2_bf16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_fp16; \n\t"
".reg.b16 val2_fp16; \n\t"
"mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
"cvt.f32.f16 val1, val1_fp16; \n\t"
"cvt.f32.f16 val2, val2_fp16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_fp16; \n\t"
".reg.b16 val2_fp16; \n\t"
"mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
"cvt.f32.f16 val1, val1_fp16; \n\t"
"cvt.f32.f16 val2, val2_fp16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2)));
}
__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2)));
}
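// Illustrative usage sketch (not part of the library API): combine abs_max_2x and
// mul_cvt_2x the way the rowwise section of the cast kernel does, reducing a packed
// bf16 pair to a scalar amax and scale-casting a packed pair to two e4m3 values.
__device__ __forceinline__ float packed_abs_max_sketch(const bf16x2 &a, const bf16x2 &b) {
  bf16x2 m;
  abs_max_2x(m, a, b);  // packed max of absolute values
  return static_cast<float>(__hmax(__habs(m.x), __habs(m.y)));  // horizontal reduction
}
__device__ __forceinline__ fp8e4m3x2 packed_scale_cast_sketch(const bf16x2 &in,
                                                              const float scale_inverse) {
  const floatx2 scale_2x = {scale_inverse, scale_inverse};
  fp8e4m3x2 out;
  mul_cvt_2x(out, in, scale_2x);  // fused multiply + satfinite cast to e4m3
  return out;
}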
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx
@@ -905,10 +905,7 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using e8m0_t = uint8_t;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 };
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };
template <typename T>
struct Numeric_Traits;
@@ -934,44 +931,6 @@ struct Quantized_Limits {
static constexpr float emax_rcp = 1.0 / emax;
};
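// End-to-end sketch of how these limits drive MXFP8 block quantization in the kernels
// above (plain scalar pseudocode; names mirror the kernel variables):
//
//   float block_amax = ...;                            // abs-max over one 1x32 block
//   e8m0_t biased_exp =
//       float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);
//   float scale_inv = exp2f_rcp(biased_exp);           // 2^-(biased_exp - 127)
//   out[i] = static_cast<OType>(in[i] * scale_inv);    // per-element scaled cast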
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
// TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax.
if (isnan(val)) {
return 0xFF;
}
if (isinf(val)) {
return 0xFE;
}
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out;
asm volatile(
"{\n"
"cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
"}"
: "=h"(out)
: "f"(val));
return *reinterpret_cast<e8m0_t *>(&out);
#else
if (val == 0.0f) {
return 0x00;
}
uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent;
}
return exponent;
#endif
}
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_