Unverified Commit cb504cda authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[Common] Improved performance of mxfp8 cast kernels (#1628)



* Fixed conflicts
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor code refactoring to avoid unnecessary checks
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed typo
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed dBias accumulation error due to initialization. Minor code refactoring
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Test case to reproduce the init error
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed rowwise dbias error
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed ptx API
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added a struct for two packed FP8 values
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Rolled back to scalar code for columnwise scaling due to its better performance
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Minor corrections
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Rebased on main
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixes per code review
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed constexpr in C++ test suite to build faster
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Computed activations are now numerically truncated to InputType before scaling. Improved test suite.
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor refactoring
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Minor refactoring
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Modified mismatches checks of MXFP8 to address FP8 numerics
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Implemented Jeremy's fixes to JAX test suite with an intermediate downcast
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Reduced the dims of the test tensors to improve CI runtime
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed memory alignment issue. Compute dbias without downcast.
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed misaligned memory issue also in gated kernels. Reduced size of MXFP8 gated tests
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 315b47db
...@@ -36,25 +36,54 @@ enum ActivationType { ...@@ -36,25 +36,54 @@ enum ActivationType {
SReLU SReLU
}; };
template <typename InputType, typename OutputType, float (*OP)(const float)> template <typename InputType, typename OutputType>
void scale_block(const ProcessingMethod processing_method, void compute_ref(const ProcessingMethod processing_method,
float (*OP)(const float),
const bool rowwise,
const bool colwise,
const InputType* input, const InputType* input,
const InputType* grad, const InputType* grad,
OutputType* output_c, OutputType* output_rowwise,
float* dbias, OutputType* output_colwise,
fp8e8m0* output_scales, fp8e8m0* output_scales_rowwise,
const size_t scale_idx, fp8e8m0* output_scales_colwise,
const size_t i_min, InputType* output_dbias,
const size_t i_max, const size_t rows,
const size_t j_min, const size_t cols,
const size_t j_max, const size_t scales_stride_rowwise,
const size_t cols) { const size_t scales_stride_colwise)
float amax = 0.0f; {
const size_t tile_size_Y = 32;
// Find the absolute maximum value in the block const size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer(tile_size_Y * tile_size_X);
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(i_min + tile_size_Y, rows);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(j_min + tile_size_X, cols);
// Cache computations
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j; const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]); float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) { if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input // grad is the input
...@@ -68,89 +97,58 @@ void scale_block(const ProcessingMethod processing_method, ...@@ -68,89 +97,58 @@ void scale_block(const ProcessingMethod processing_method,
processing_method == ProcessingMethod::CAST_DBIAS_DACT) { processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]); elt *= static_cast<float>(grad[idx]);
} }
dbias[j] += elt; thread_dbias[j] += elt;
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
elt = static_cast<float>(static_cast<InputType>(elt));
cache_buffer[cache_idx] = elt;
if (isinf(elt) || isnan(elt)) { if (isinf(elt) || isnan(elt)) {
continue; continue;
} }
amax = std::max(amax, std::abs(elt));
} }
} }
const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits<OutputType>::max_reciprocal()); if (rowwise) {
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j; const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]); block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
} }
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) { const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
elt = OP(elt); const int scale_idx = i * scales_stride_rowwise + tile_X;
} output_scales_rowwise[scale_idx] = biased_exponent;
if (processing_method == ProcessingMethod::CAST_DACT || const float scale_reciprocal = exp2f_rcp(biased_exponent);
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]); for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
} }
output_c[idx] = static_cast<OutputType>(elt * scale_reciprocal);
} }
} }
} if (colwise) {
for (size_t j = j_min; j < j_max; ++j) {
template <typename InputType, typename OutputType, float (*OP)(const float)> float block_amax = 0.0f;
void compute_ref_x1(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
fp8e8m0* output_scales,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride)
{
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { for (size_t i = i_min; i < i_max; ++i) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t block_offset_Y = ii * block_size_Y; block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
const size_t i_min = tile_offset_Y + block_offset_Y; }
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; const int scale_idx = tile_Y * scales_stride_colwise + j;
const size_t block_offset_X = jj * block_size_X; output_scales_colwise[scale_idx] = biased_exponent;
const size_t j_min = tile_offset_X + block_offset_X; const float scale_reciprocal = exp2f_rcp(biased_exponent);
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X; for (size_t i = i_min; i < i_max; ++i) {
scale_block<InputType, OutputType, OP>( const int idx = i * cols + j;
processing_method, input, grad, output_c, thread_dbias.data(), const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
} }
} }
} }
...@@ -166,29 +164,6 @@ void compute_ref_x1(const ProcessingMethod processing_method, ...@@ -166,29 +164,6 @@ void compute_ref_x1(const ProcessingMethod processing_method,
} }
} }
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x2(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/** /**
* Scaling along single dimension (either rows or columns) * Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias): * Produces one set of output data and the corresponding data of the fused operation (dbias):
...@@ -197,8 +172,9 @@ void compute_ref_x2(const ProcessingMethod processing_method, ...@@ -197,8 +172,9 @@ void compute_ref_x2(const ProcessingMethod processing_method,
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <typename InputType, typename OutputType, float (*OP)(const float)> template <typename InputType, typename OutputType>
void performTest_x1(const ProcessingMethod processing_method, void performTest_x1(const ProcessingMethod processing_method,
float (*OP)(const float),
const std::vector<size_t>& shape, const std::vector<size_t>& shape,
const bool rowwise, const bool rowwise,
const bool colwise, const bool colwise,
...@@ -261,7 +237,13 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -261,7 +237,13 @@ void performTest_x1(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DBIAS_DACT: { case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(), auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }
nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output_c.data(), output_c.data(),
output_dbias.data(), output_dbias.data(),
...@@ -269,7 +251,7 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -269,7 +251,7 @@ void performTest_x1(const ProcessingMethod processing_method,
0); 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(), nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output_c.data(), output_c.data(),
output_dbias.data(), output_dbias.data(),
...@@ -278,11 +260,23 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -278,11 +260,23 @@ void performTest_x1(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DACT: { case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); auto nvte_dact = &nvte_dgelu;
if (OP == &dsilu) { nvte_dact = &nvte_dsilu; }
else if (OP == &drelu) { nvte_dact = &nvte_drelu; }
else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }
nvte_dact(grad.data(), input.data(), output_c.data(), 0);
break; break;
} }
case ProcessingMethod::CAST_ACT: { case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output_c.data(), 0); auto nvte_act = &nvte_gelu;
if (OP == &silu) { nvte_act = &nvte_silu; }
else if (OP == &relu) { nvte_act = &nvte_relu; }
else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
else if (OP == &srelu) { nvte_act = &nvte_srelu; }
nvte_act(input.data(), output_c.data(), 0);
break; break;
} }
} }
...@@ -291,29 +285,45 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -291,29 +285,45 @@ void performTest_x1(const ProcessingMethod processing_method,
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x1<InputType, OutputType, OP>(processing_method, compute_ref<InputType, OutputType>(processing_method,
OP,
rowwise,
colwise,
input.rowwise_cpu_dptr<InputType>(), input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(), grad.rowwise_cpu_dptr<InputType>(),
ref_output_c.get(), ref_output_c.get(),
ref_output_c.get(),
ref_output_scales.get(),
ref_output_scales.get(), ref_output_scales.get(),
ref_output_dbias.get(), ref_output_dbias.get(),
rows, rows,
cols, cols,
block_size_rows, scales_stride,
block_size_cols,
scales_stride); scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise const uint8_t * const gpu_scales_ptr = rowwise
? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>() ? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>(); : output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales = 0;
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride); unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype); auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) { if (itype == DType::kFloat32) {
atol_dbias = 1e-4; atol_dbias = 1e-4;
...@@ -332,8 +342,9 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -332,8 +342,9 @@ void performTest_x1(const ProcessingMethod processing_method,
* AND * AND
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <typename InputType, typename OutputType, float (*OP)(const float)> template <typename InputType, typename OutputType>
void performTest_x2(const ProcessingMethod processing_method, void performTest_x2(const ProcessingMethod processing_method,
float (*OP)(const float),
const std::vector<size_t>& shape, const std::vector<size_t>& shape,
const size_t block_size_rows, const size_t block_size_rows,
const size_t block_size_cols, const size_t block_size_cols,
...@@ -401,7 +412,13 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -401,7 +412,13 @@ void performTest_x2(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DBIAS_DACT: { case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(), auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }
nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output.data(), output.data(),
output_dbias.data(), output_dbias.data(),
...@@ -409,7 +426,7 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -409,7 +426,7 @@ void performTest_x2(const ProcessingMethod processing_method,
0); 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(), nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output.data(), output.data(),
output_dbias.data(), output_dbias.data(),
...@@ -418,11 +435,23 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -418,11 +435,23 @@ void performTest_x2(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DACT: { case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output.data(), 0); auto nvte_dact = &nvte_dgelu;
if (OP == &dsilu) { nvte_dact = &nvte_dsilu; }
else if (OP == &drelu) { nvte_dact = &nvte_drelu; }
else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }
nvte_dact(grad.data(), input.data(), output.data(), 0);
break; break;
} }
case ProcessingMethod::CAST_ACT: { case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output.data(), 0); auto nvte_act = &nvte_gelu;
if (OP == &silu) { nvte_act = &nvte_silu; }
else if (OP == &relu) { nvte_act = &nvte_relu; }
else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
else if (OP == &srelu) { nvte_act = &nvte_srelu; }
nvte_act(input.data(), output.data(), 0);
break; break;
} }
} }
...@@ -431,7 +460,10 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -431,7 +460,10 @@ void performTest_x2(const ProcessingMethod processing_method,
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x2<InputType, OutputType, OP>(processing_method, compute_ref<InputType, OutputType>(processing_method,
OP,
true,
true,
input.rowwise_cpu_dptr<InputType>(), input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(), grad.rowwise_cpu_dptr<InputType>(),
ref_output_c_rowwise.get(), ref_output_c_rowwise.get(),
...@@ -441,22 +473,41 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -441,22 +473,41 @@ void performTest_x2(const ProcessingMethod processing_method,
ref_output_dbias.get(), ref_output_dbias.get(),
rows, rows,
cols, cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise, scales_stride_rowwise,
scales_stride_colwise); scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype); const size_t scale_diff_abs_tolerance = 0;
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); const double abs_tolerable_mismatches_limit = 0.0;
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise); unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise, ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise); unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise);
if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype); auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) { if (itype == DType::kFloat32) {
atol_dbias = 1e-4; atol_dbias = 1e-4;
...@@ -475,11 +526,10 @@ std::vector<std::vector<size_t>> matrix_sizes = { ...@@ -475,11 +526,10 @@ std::vector<std::vector<size_t>> matrix_sizes = {
{128, 128}, {128, 128},
{256, 256}, {256, 256},
{993, 512}, {993, 512},
{256, 65536}, {511, 6144},
{2048, 6144}, {8192, 128},
{16384, 128}, {2048, 160},
{32768, 160}, {577, 1632},
{4096, 1632},
{1024}, {1024},
{8, 32, 1024}, {8, 32, 1024},
{16, 8, 4, 512}, {16, 8, 4, 512},
...@@ -528,26 +578,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam ...@@ -528,26 +578,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
transformer_engine::DType, transformer_engine::DType,
InputsFillCase>> {}; InputsFillCase>> {};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
}
TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
// Skip tests for pre-Blackwell architectures // Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) { if (getDeviceComputeCapability() < blackwellComputeCapability) {
...@@ -581,37 +611,50 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { ...@@ -581,37 +611,50 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
const bool colwise = block_size.first != 1; const bool colwise = block_size.first != 1;
if (processing_method == ProcessingMethod::CAST_ACT) { if (processing_method == ProcessingMethod::CAST_ACT) {
// Forward activations // Forward activations
ACT_FUNC_SWITCH(Act_type, OP, auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &gelu; break;
case ActivationType::SiLU: OP = &silu; break;
case ActivationType::ReLU: OP = &relu; break;
case ActivationType::QGeLU: OP = &qgelu; break;
case ActivationType::SReLU: OP = &srelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) { if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>( performTest_x1<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
rowwise, colwise, fill_case); rowwise, colwise, fill_case);
} else { } else {
performTest_x2<InputType, OutputType, OP>( performTest_x2<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
block_size.first, block_size.second, fill_case); block_size.first, block_size.second, fill_case);
} }
); );
); );
);
} else { } else {
DACT_FUNC_SWITCH(Act_type, OP, auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &dgelu; break;
case ActivationType::SiLU: OP = &dsilu; break;
case ActivationType::ReLU: OP = &drelu; break;
case ActivationType::QGeLU: OP = &dqgelu; break;
case ActivationType::SReLU: OP = &dsrelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) { if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>( performTest_x1<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
rowwise, colwise, fill_case); rowwise, colwise, fill_case);
} else { } else {
performTest_x2<InputType, OutputType, OP>( performTest_x2<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
block_size.first, block_size.second, fill_case); block_size.first, block_size.second, fill_case);
} }
); );
); );
);
} }
} }
......
...@@ -18,134 +18,157 @@ using namespace test; ...@@ -18,134 +18,157 @@ using namespace test;
namespace { namespace {
template <bool IS_DGATED, typename IType, typename OType> template <typename IType, typename OType>
void scale_block(const IType* grad, void compute_ref(const IType* grad,
const IType* input, const IType* input,
OType* output, OType* output_rowwise,
fp8e8m0* output_scales, OType* output_colwise,
const size_t scale_idx, fp8e8m0* output_scales_rowwise,
const size_t scale_idx_gate, fp8e8m0* output_scales_colwise,
float& thread_amax, float& ref_amax,
const size_t i_min, const bool IS_DGATED,
const size_t i_max, const size_t rows,
const size_t j_min, const size_t cols,
const size_t j_max, const size_t scales_stride_rowwise,
const size_t cols) { const size_t scales_stride_colwise,
const bool is_rowwise,
float block_amax = 0.0f; const bool is_colwise) {
float block_amax_gate = 0.0f; constexpr size_t tile_size_Y = 32;
constexpr size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
float amax = 0;
#pragma omp parallel reduction(max: amax) proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer_act(tile_size_Y * tile_size_X);
std::vector<float> cache_buffer_gate(tile_size_Y * tile_size_X);
float thread_amax = 0.0f;
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
const size_t stride = cols * 2; const size_t stride = cols * 2;
// Find the absolute maximum value in the block const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(rows, tile_offset_Y + tile_size_Y);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(cols, tile_offset_X + tile_size_X);
// Compute and cache activations for the entire tile
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]); float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]); float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float gated_amax_act = 0;
float gated_amax_gate = 0;
if constexpr (IS_DGATED) { const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
if (IS_DGATED) {
const float x = silu_elt;
const float s = sigmoid(x);
const float act_x = x * s;
const float dact_x = x * s * (1 - s) + s;
const float grad_elt = static_cast<float>(grad[i * cols + j]); const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; float after_dsilu = dact_x * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt; float after_dgate = act_x * grad_elt;
gated_amax_act = abs(after_dsilu);
gated_amax_gate = abs(after_dgate); // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32
after_dsilu = static_cast<float>(static_cast<IType>(after_dsilu));
after_dgate = static_cast<float>(static_cast<IType>(after_dgate));
cache_buffer_act[cached_idx] = after_dsilu;
cache_buffer_gate[cached_idx] = after_dgate;
thread_amax = std::max(thread_amax, std::abs(after_dsilu));
thread_amax = std::max(thread_amax, std::abs(after_dgate));
} else { } else {
const float after_silu = silu(silu_elt) * gate_elt; float after_silu = silu(silu_elt) * gate_elt;
gated_amax_act = abs(after_silu);
} // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32
after_silu = static_cast<float>(static_cast<IType>(after_silu));
if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } cache_buffer_act[cached_idx] = after_silu;
if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } thread_amax = std::max(thread_amax, std::abs(after_silu));
} }
} }
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
float scale_reciprocal_gate = 1;
if constexpr (IS_DGATED) {
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent);
output_scales[scale_idx_gate] = biased_exponent;
} }
if (is_rowwise) {
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]); const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
float gate_elt = static_cast<float>(input[i * stride + cols + j]); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
}
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx_act] = biased_exponent_act;
if constexpr (IS_DGATED) { float scale_reciprocal_gate;
const float grad_elt = static_cast<float>(grad[i * cols + j]); if (IS_DGATED) {
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
const float after_dgate = silu(silu_elt) * grad_elt; scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal); const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
output[i * stride + cols + j] = static_cast<OType>(after_dgate * output_scales_rowwise[scale_idx_gate] = biased_exponent_gate;
scale_reciprocal_gate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
} }
for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate;
output_rowwise[i * stride + j] = static_cast<OType>(after_act);
output_rowwise[i * stride + cols + j] = static_cast<OType>(after_gate);
} else {
output_rowwise[i * cols + j] = static_cast<OType>(after_act);
}
}
} }
} }
thread_amax = std::max(thread_amax, block_amax);
thread_amax = std::max(thread_amax, block_amax_gate);
}
template <bool IS_DGATED, typename IType, typename OType> if (is_colwise) {
void compute_ref_x1(const IType* grad, for (size_t j = j_min; j < j_max; ++j) {
const IType* input, float block_amax_act = 0.0f;
OType* output, float block_amax_gate = 0.0f;
fp8e8m0* output_scales, for (size_t i = i_min; i < i_max; ++i) {
float& ref_amax, const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t rows, block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
const size_t cols, if (IS_DGATED) {
const size_t block_size_Y, block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
const size_t block_size_X, }
const size_t scales_stride) { }
const size_t tile_size_Y = std::max(32lu, block_size_Y); const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const size_t tile_size_X = std::max(64lu, block_size_X); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const int scale_idx_act = tile_Y * scales_stride_colwise + j;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; output_scales_colwise[scale_idx_act] = biased_exponent_act;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
float amax = 0; float scale_reciprocal_gate;
#pragma omp parallel reduction(max: amax) proc_bind(spread) if (IS_DGATED) {
{ const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
float thread_amax = 0; const int scale_idx_gate = scale_idx_act + cols;
#pragma omp for schedule(static) scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { output_scales_colwise[scale_idx_gate] = biased_exponent_gate;
const size_t tile_Y = t / tiles_num_X; }
const size_t tile_X = t % tiles_num_X; for (size_t i = i_min; i < i_max; ++i) {
const size_t tile_offset_Y = tile_Y * tile_size_Y; const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t tile_offset_X = tile_X * tile_size_X; const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { if (IS_DGATED) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate;
const size_t block_offset_Y = ii * block_size_Y; output_colwise[i * stride + j] = static_cast<OType>(after_act);
const size_t i_min = tile_offset_Y + block_offset_Y; output_colwise[i * stride + cols + j] = static_cast<OType>(after_gate);
if (i_min >= rows) continue; } else {
const size_t i_max = std::min(i_min + block_size_Y, rows); output_colwise[i * cols + j] = static_cast<OType>(after_act);
}
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { }
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
cols / block_size_X;
scale_block<IS_DGATED, IType, OType>(
grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
thread_amax, i_min, i_max, j_min, j_max, cols);
} }
} }
} }
...@@ -156,26 +179,6 @@ void compute_ref_x1(const IType* grad, ...@@ -156,26 +179,6 @@ void compute_ref_x1(const IType* grad,
ref_amax = amax; ref_amax = amax;
} }
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x2(const IType* grad,
const IType* input,
OType* output_rowwise,
OType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/** /**
* Scaling along single dimension (either rows or columns) * Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias): * Produces one set of output data and the corresponding data of the fused operation (dbias):
...@@ -183,12 +186,13 @@ void compute_ref_x2(const IType* grad, ...@@ -183,12 +186,13 @@ void compute_ref_x2(const IType* grad,
* OR * OR
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <bool IS_DGATED, typename IType, typename OType> template <typename IType, typename OType>
void performTest_x1(const size_t rows, void performTest_x1(const size_t rows,
const size_t cols, const size_t cols,
const size_t block_size_rows, const size_t block_size_rows,
const size_t block_size_cols, const size_t block_size_cols,
InputsFillCase fill_case) { InputsFillCase fill_case,
const bool IS_DGATED) {
using namespace test; using namespace test;
using EncodingType = fp32; using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype; DType itype = TypeInfo<IType>::dtype;
...@@ -198,12 +202,6 @@ void performTest_x1(const size_t rows, ...@@ -198,12 +202,6 @@ void performTest_x1(const size_t rows,
const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
NVTE_CHECK(rowwise || colwise); NVTE_CHECK(rowwise || colwise);
// std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl;
// std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl;
// std::cout << "blocks_Y: " << blocks_Y << std::endl;
// std::cout << "blocks_X: " << blocks_X << std::endl;
// std::cout << "scales_stride: " << scales_stride << std::endl;
Tensor grad("grad", std::vector<size_t>{ rows, cols }, itype); Tensor grad("grad", std::vector<size_t>{ rows, cols }, itype);
Tensor input("input", std::vector<size_t>{ rows, cols * 2 }, itype); Tensor input("input", std::vector<size_t>{ rows, cols * 2 }, itype);
...@@ -229,12 +227,12 @@ void performTest_x1(const size_t rows, ...@@ -229,12 +227,12 @@ void performTest_x1(const size_t rows,
} }
// fillCase<EncodingType>(&grad, fill_case); // fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) { if (IS_DGATED) {
fillUniform(&grad); fillUniform(&grad);
} }
fillUniform(&input); fillUniform(&input);
if constexpr (IS_DGATED) { if (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0); nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else { } else {
nvte_swiglu(input.data(), output.data(), 0); nvte_swiglu(input.data(), output.data(), 0);
...@@ -245,30 +243,48 @@ void performTest_x1(const size_t rows, ...@@ -245,30 +243,48 @@ void performTest_x1(const size_t rows,
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0; float ref_amax = 0;
compute_ref_x1<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(), compute_ref<IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(), input.rowwise_cpu_dptr<IType>(),
ref_output.get(), ref_output.get(),
ref_output.get(),
ref_output_scales.get(),
ref_output_scales.get(), ref_output_scales.get(),
ref_amax, ref_amax,
IS_DGATED,
rows, rows,
cols, cols,
block_size_rows, scales_stride,
block_size_cols, scales_stride,
scales_stride); rowwise,
colwise);
auto [atol, rtol] = getTolerances(otype); size_t mismatches_scales = 0;
compareResults("output", output, ref_output.get(), rowwise, atol, rtol); const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 1.0;
const double rel_tolerable_mismatches_limit = 1.0e-4;
const uint8_t * const gpu_scales_ptr = rowwise const uint8_t * const gpu_scales_ptr = rowwise
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>() ? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(); : output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) { if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride); unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} else { } else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride); unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} }
const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts);
} }
/** /**
...@@ -278,12 +294,13 @@ void performTest_x1(const size_t rows, ...@@ -278,12 +294,13 @@ void performTest_x1(const size_t rows,
* AND * AND
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <bool IS_DGATED, typename IType, typename OType> template <typename IType, typename OType>
void performTest_x2(const size_t rows, void performTest_x2(const size_t rows,
const size_t cols, const size_t cols,
const size_t block_size_rows, const size_t block_size_rows,
const size_t block_size_cols, const size_t block_size_cols,
InputsFillCase fill_case) { InputsFillCase fill_case,
const bool IS_DGATED) {
using namespace test; using namespace test;
using EncodingType = fp32; using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype; DType itype = TypeInfo<IType>::dtype;
...@@ -325,12 +342,12 @@ void performTest_x2(const size_t rows, ...@@ -325,12 +342,12 @@ void performTest_x2(const size_t rows,
} }
// fillCase<EncodingType>(&grad, fill_case); // fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) { if (IS_DGATED) {
fillUniform(&grad); fillUniform(&grad);
} }
fillUniform(&input); fillUniform(&input);
if constexpr (IS_DGATED) { if (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0); nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else { } else {
nvte_swiglu(input.data(), output.data(), 0); nvte_swiglu(input.data(), output.data(), 0);
...@@ -341,30 +358,49 @@ void performTest_x2(const size_t rows, ...@@ -341,30 +358,49 @@ void performTest_x2(const size_t rows,
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0; float ref_amax = 0;
compute_ref_x2<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(), compute_ref<IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(), input.rowwise_cpu_dptr<IType>(),
ref_output_rowwise.get(), ref_output_rowwise.get(),
ref_output_colwise.get(), ref_output_colwise.get(),
ref_scales_rowwise.get(), ref_scales_rowwise.get(),
ref_scales_colwise.get(), ref_scales_colwise.get(),
ref_amax, ref_amax,
IS_DGATED,
rows, rows,
cols, cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise, scales_stride_rowwise,
scales_stride_colwise); scales_stride_colwise,
true,
true);
auto [atol, rtol] = getTolerances(otype); const size_t scale_diff_abs_tolerance = 0;
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); const double abs_tolerable_mismatches_limit = 1.0;
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); const double rel_tolerable_mismatches_limit = 1.0e-4;
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise); unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise, ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise); unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise);
} }
std::vector<std::pair<size_t, size_t>> matrix_sizes = { std::vector<std::pair<size_t, size_t>> matrix_sizes = {
...@@ -375,8 +411,8 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = { ...@@ -375,8 +411,8 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{256, 256}, {256, 256},
{993, 512}, {993, 512},
{768, 1024}, {768, 1024},
{65504, 128}, {8192, 128},
{16384, 1632}, {577, 1632},
}; };
std::vector<std::pair<size_t, size_t>> block_sizes = { std::vector<std::pair<size_t, size_t>> block_sizes = {
...@@ -393,9 +429,9 @@ std::vector<InputsFillCase> input_scenarios = { ...@@ -393,9 +429,9 @@ std::vector<InputsFillCase> input_scenarios = {
// InputsFillCase::maxNorm_to_inf // InputsFillCase::maxNorm_to_inf
}; };
std::vector<bool> is_dgated_op = { std::vector<bool> is_bwd_op = {
true, false,
false true
}; };
} // namespace } // namespace
...@@ -427,21 +463,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { ...@@ -427,21 +463,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType,
if (block_size.first == 1 || block_size.second == 1) { if (block_size.first == 1 || block_size.second == 1) {
if (IS_DGATED) { performTest_x1<IType, OType>(matrix_size.first, matrix_size.second,
performTest_x1<true, IType, OType>(matrix_size.first, matrix_size.second, block_size.first, block_size.second, fill_case, IS_DGATED);
block_size.first, block_size.second, fill_case);
} else {
performTest_x1<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
} else { } else {
if (IS_DGATED) { performTest_x2<IType, OType>(matrix_size.first, matrix_size.second,
performTest_x2<true, IType, OType>(matrix_size.first, matrix_size.second, block_size.first, block_size.second, fill_case, IS_DGATED);
block_size.first, block_size.second, fill_case);
} else {
performTest_x2<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
} }
); );
); );
...@@ -456,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -456,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::ValuesIn(input_scenarios),
::testing::ValuesIn(is_dgated_op)), ::testing::ValuesIn(is_bwd_op)),
[](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) { [](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) {
std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" + std::to_string(std::get<0>(info.param).second) + "X" +
...@@ -465,6 +491,6 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -465,6 +491,6 @@ INSTANTIATE_TEST_SUITE_P(
test::typeName(std::get<2>(info.param)) + "X" + test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" + test::typeName(std::get<3>(info.param)) + "X" +
test::caseName(std::get<4>(info.param)) + "X" + test::caseName(std::get<4>(info.param)) + "X" +
(std::get<5>(info.param) ? "DGATED" : "GATED"); (std::get<5>(info.param) ? "BWD" : "FWD");
return name; return name;
}); });
...@@ -523,10 +523,13 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) { ...@@ -523,10 +523,13 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
void compareResults_sequential(const std::string &name, const Tensor &test, void compareResults_sequential(const std::string &name, const Tensor &test,
const void *ref, const bool rowwise, const void *ref, const bool rowwise,
double atol, double rtol, bool if_on_gpus) { double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
if (if_on_gpus) test.to_cpu(); if (if_on_gpus) test.to_cpu();
const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
const size_t N = product(shape); const size_t N = product(shape);
size_t mismatches_num = 0;
int first_mismatch_idx = -1;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>(); const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref); const T *ref_data = reinterpret_cast<const T*>(ref);
...@@ -547,27 +550,39 @@ void compareResults_sequential(const std::string &name, const Tensor &test, ...@@ -547,27 +550,39 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
} }
std::string direction = rowwise ? "rowwise" : "columnwise"; std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " if (assertion) {
mismatches_num++;
if (first_mismatch_idx == -1) {
first_mismatch_idx = i;
}
}
if (mismatches_num > tolerable_mismatches_limit) {
const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);
GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl << direction << " direction." << std::endl
<< "Mismatch at place " << to_string(unravel(i, shape)) << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
<< " (" << std::to_string(i) << "): " << t << " vs " << r; << " (" << std::to_string(first_mismatch_idx) << "): "
<< first_mismatch_t << " vs " << first_mismatch_r;
}
} }
); );
} }
template <typename T> template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
const size_t N, const double atol, const double rtol) { const size_t N, const double atol, const double rtol,
size_t& mismatches) {
int first_mismatch_idx = N; int first_mismatch_idx = N;
bool is_mismatch_found = false; #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
#pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ {
reduction(min: first_mismatch_idx) proc_bind(spread) size_t thread_mismatches = 0;
#pragma omp for schedule(static)
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
if (is_mismatch_found) { // early escape of the omp thread
continue;
}
double t = static_cast<double>(test_data[i]); double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]); double r = static_cast<double>(ref_data[i]);
...@@ -584,29 +599,38 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con ...@@ -584,29 +599,38 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m)); const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
} }
if (assertion && i < first_mismatch_idx) { if (assertion) {
if (i < first_mismatch_idx) {
first_mismatch_idx = i; first_mismatch_idx = i;
is_mismatch_found = true; }
thread_mismatches++;
} }
} }
mismatches += thread_mismatches;
}
return first_mismatch_idx; return first_mismatch_idx;
} }
void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
const bool rowwise, double atol, double rtol, bool if_on_gpus) { const bool rowwise, double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
if (if_on_gpus) test.to_cpu(); if (if_on_gpus) test.to_cpu();
const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
const size_t N = product(shape); const size_t N = product(shape);
size_t mismatches = 0;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>(); const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref); const T *ref_data = reinterpret_cast<const T*>(ref);
const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol); const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
if (i != N) { if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
const double t = static_cast<double>(test_data[i]); const double t = static_cast<double>(test_data[i]);
const double r = static_cast<double>(ref_data[i]); const double r = static_cast<double>(ref_data[i]);
std::string direction = rowwise ? "rowwise" : "columnwise"; std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(true) << "Error in tensor " << name << " in "
GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl << direction << " direction." << std::endl
<< "Mismatch at place " << to_string(unravel(i, shape)) << "Mismatch at place " << to_string(unravel(i, shape))
<< " (" << std::to_string(i) << "): " << t << " vs " << r; << " (" << std::to_string(i) << "): " << t << " vs " << r;
...@@ -615,12 +639,13 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const ...@@ -615,12 +639,13 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const
} }
void compareResults(const std::string &name, const Tensor &test, const void *ref, void compareResults(const std::string &name, const Tensor &test, const void *ref,
const bool rowwise, double atol, double rtol, bool if_on_gpus) { const bool rowwise, double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
constexpr bool sequential = false; constexpr bool sequential = false;
if constexpr (sequential) { if constexpr (sequential) {
compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
} else { } else {
compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
} }
} }
...@@ -657,25 +682,39 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t ...@@ -657,25 +682,39 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
} }
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride) const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
{ {
const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit));
mismatches_num = 0;
std::vector<int> mismatch_indices;
for (int i = 0; i < row_blocks; ++i) { for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) { for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j; const int idx = i * stride + j;
ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl const int test_val = static_cast<int>(test[idx]);
<< "Mismatch: " << static_cast<int>(test[idx]) << " vs " const int ref_val = static_cast<int>(ref[idx]);
<< static_cast<int>(ref[idx]) << " at index " << idx; const int abs_delta = std::abs(test_val - ref_val);
if (abs_delta > atol) {
mismatches_num++;
mismatch_indices.push_back(idx);
}
if (mismatches_num > tolerable_mismatches_limit) {
std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):"
<< static_cast<int>(test[index]) << " vs "
<< static_cast<int>(ref[index]) << std::endl;
}
GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << ".";
} }
} }
}
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t N)
{
for (int i = 0; i < N; i++) {
ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
<< "Mismatch: " << static_cast<int>(test[i]) << " vs "
<< static_cast<int>(ref[i]) << " at index " << i;
} }
} }
......
...@@ -413,7 +413,12 @@ inline fp8e8m0 float_to_e8m0(float val) { ...@@ -413,7 +413,12 @@ inline fp8e8m0 float_to_e8m0(float val) {
} }
inline float exp2f_rcp(fp8e8m0 biased_exp) { inline float exp2f_rcp(fp8e8m0 biased_exp) {
return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp)); if (biased_exp == 0) {
return 1.0f;
}
int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127)
float fp32_val = *reinterpret_cast<float*>(&int_val);
return fp32_val;
} }
inline float identity(const float x) { return x; } inline float identity(const float x) { return x; }
...@@ -445,15 +450,18 @@ size_t last_dimension(const std::vector<size_t> &shape); ...@@ -445,15 +450,18 @@ size_t last_dimension(const std::vector<size_t> &shape);
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
void compareResults(const std::string &name, const Tensor &test, const void *ref, void compareResults(const std::string &name, const Tensor &test, const void *ref,
bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
const size_t tolerable_mismatches_limit = 0);
void compareResults(const std::string &name, const float test, const float ref, void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8); double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.); size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride); const size_t row_blocks, const size_t col_blocks, const size_t stride,
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t& mismatches_num,
const size_t N); const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols, std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols); const size_t block_size_rows, const size_t block_size_cols);
......
...@@ -78,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape): ...@@ -78,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape):
return False return False
def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): def assert_bitwise_scaled_tensors(
a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
if not precise_comparison:
assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
return
assert a.scaling_mode == b.scaling_mode assert a.scaling_mode == b.scaling_mode
assert a.scale_inv.dtype == b.scale_inv.dtype assert a.scale_inv.dtype == b.scale_inv.dtype
if a.scaling_mode.is_tensor_scaling(): if a.scaling_mode.is_tensor_scaling():
...@@ -94,8 +100,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): ...@@ -94,8 +100,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
assert_allclose(a.data, b.data) assert_allclose(a.data, b.data)
elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x): elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor) assert_bitwise_scaled_tensors(
assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor) a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
)
assert_bitwise_scaled_tensors(
a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
)
else: else:
pytest.fail("Unsupported input types") pytest.fail("Unsupported input types")
...@@ -481,24 +491,7 @@ class TestNorm: ...@@ -481,24 +491,7 @@ class TestNorm:
# if the input dtype is not float32 # if the input dtype is not float32
precise_comparison = False precise_comparison = False
if precise_comparison: assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)
assert_bitwise_scaled_tensors(output, ref_out)
else:
if isinstance(ref_out, ScaledTensor1x):
assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
elif isinstance(ref_out, ScaledTensor2x):
assert_allclose(
output.rowwise_tensor.dequantize(),
ref_out.rowwise_tensor.dequantize(),
dtype=out_dtype,
)
assert_allclose(
output.colwise_tensor.dequantize(),
ref_out.colwise_tensor.dequantize(),
dtype=out_dtype,
)
else:
pytest.fail("Unsupported output type")
assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype) assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
if norm_type == "layernorm": if norm_type == "layernorm":
...@@ -768,12 +761,24 @@ class TestFusedQuantize: ...@@ -768,12 +761,24 @@ class TestFusedQuantize:
)(dz, x) )(dz, x)
if is_casted_output: if is_casted_output:
assert_bitwise_scaled_tensors(te_output, jax_output) # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation
precise_comparison = not (
in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
)
assert_bitwise_scaled_tensors(
te_output, jax_output, precise_comparison=precise_comparison
)
else: else:
assert_allclose(te_output, jax_output) assert_allclose(te_output, jax_output)
if is_dbias: if is_dbias:
assert_allclose(te_dbias, jax_dbias) # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
precise_comparison = not (
in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
)
assert_allclose(
te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
......
...@@ -192,6 +192,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) ...@@ -192,6 +192,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu set_source_files_properties(activation/gelu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
util/cast.cu
PROPERTIES PROPERTIES
COMPILE_OPTIONS "--use_fast_math") COMPILE_OPTIONS "--use_fast_math")
endif() endif()
......
...@@ -162,10 +162,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, ...@@ -162,10 +162,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
(offset_elems * type_num_bits) / 8); (offset_elems * type_num_bits) / 8);
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
"Tensor data pointer must be 16B aligned"); "Tensor data pointer must be 16B aligned");
const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits, NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
"-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
......
...@@ -668,7 +668,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128; ...@@ -668,7 +668,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4;
// Alignment requirements for the Tensor Memory Accelerator (TMA) // Alignment requirements for the Tensor Memory Accelerator (TMA)
constexpr int TMA_gmem_alignment = 16; // global memory address alignment constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment
constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment
inline bool is_aligned_ptr(const void *ptr, size_t alignment) { inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
return reinterpret_cast<uintptr_t>(ptr) % alignment == 0; return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
......
...@@ -27,14 +27,8 @@ ...@@ -27,14 +27,8 @@
namespace transformer_engine { namespace transformer_engine {
template <typename T1, typename T2>
__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) {
return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
namespace gated_kernels { namespace gated_kernels {
constexpr size_t ALIGNMENT_SIZE = 128;
constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128; constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512; constexpr size_t THREADS_PER_CHUNK = 512;
...@@ -76,18 +70,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -76,18 +70,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float amax = 0; float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
extern __shared__ char dshmem_unaligned[]; extern __shared__ char dynamic_shmem[];
const uint64_t dshmem_unaligned_as_uint = reinterpret_cast<uint64_t>(dshmem_unaligned); uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
const uint64_t dshmem_aligned_as_uint = // Manually align dynamic SHMEM per TMA requirements using padding
DIVUP(dshmem_unaligned_as_uint, static_cast<uint64_t>(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; // __align__(128) Does not guarantee the pointer to be aligned!
char *dshmem = reinterpret_cast<char *>(dshmem_aligned_as_uint); uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in = constexpr size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out = constexpr size_t buff_size_aligned_out =
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0;
...@@ -96,8 +91,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -96,8 +91,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out; constexpr size_t out_act_mem = buff_size_aligned_out;
// const size_t in_transaction_size = grad_mem + in_mem;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType); constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
...@@ -269,9 +262,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -269,9 +262,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
namespace mxfp8_kernel {
constexpr size_t CHUNK_DIM_Y = 64;
constexpr size_t CHUNK_DIM_X = 64;
constexpr size_t THREADS_PER_CHUNK_COLWISE = 128;
constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = CHUNK_DIM_X;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType, float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
size_t SCALE_DIM_Y, size_t SCALE_DIM_X> bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) __global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act, const __grid_constant__ CUtensorMap tensor_map_input_act,
...@@ -284,43 +302,73 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -284,43 +302,73 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) { const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; using IType2 = typename ptx::FPx2<IType>;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; using OType2 = typename ptx::FPx2<OType>;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 static_assert(STAGES >= 1);
constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING;
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING);
const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension.
const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);
const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y;
const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X;
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const int thread_offset_Y = tid_Y; const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int thread_offset_X = tid_X; const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X;
const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise;
extern __shared__ char dshmem_unaligned[]; const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const uint64_t dshmem_unaligned_as_uint = reinterpret_cast<uint64_t>(dshmem_unaligned); const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
const uint64_t dshmem_aligned_as_uint = const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
DIVUP(dshmem_unaligned_as_uint, static_cast<uint64_t>(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
char *dshmem = reinterpret_cast<char *>(dshmem_aligned_as_uint);
const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols);
const size_t buff_elems_total = BUFFERS_NUM * buff_elems; const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t buff_size_aligned_out = const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
const int gate_scale_idx_offset_colwise = cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
__shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X];
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
...@@ -329,12 +377,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -329,12 +377,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t in_mem = in_act_mem + in_gate_mem; const size_t in_mem = in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out; const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out; const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
const size_t out_mem = out_act_mem + out_gate_mem; const size_t out_mem = out_act_mem + out_gate_mem;
// const size_t in_transaction_size = grad_mem + in_mem;
const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem); IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem); IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
...@@ -346,144 +391,94 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -346,144 +391,94 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
OType *out_act_colwise_sh = out_act_rowwise_sh; OType *out_act_colwise_sh = out_act_rowwise_sh;
OType *out_gate_colwise_sh = out_gate_rowwise_sh; OType *out_gate_colwise_sh = out_gate_rowwise_sh;
if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { if constexpr (ROWWISE_SCALING && COLWISE_SCALING) {
out_act_colwise_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem); out_act_colwise_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem);
out_gate_colwise_sh = out_gate_colwise_sh =
reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem + out_act_mem); reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem + out_act_mem);
} }
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad); IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act); IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values
const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
const uint64_t *TMAP_output_act_rowwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_act_rowwise);
const uint64_t *TMAP_output_gate_rowwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_rowwise);
const uint64_t *TMAP_output_act_colwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_act_colwise);
const uint64_t *TMAP_output_gate_colwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_colwise);
__shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier. // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init #pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS]; __shared__ alignas(8) uint64_t mbar[STAGES];
const bool is_master_thread = (threadIdx.x == 0); initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
if (is_master_thread) {
// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate.
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK);
}
ptx::fence_proxy_async_shared_cta();
}
// Syncthreads so initialized barrier is visible to all threads.
__syncthreads();
int parity = 0; int parity = 0;
// Prefetch data of the first stage
if (is_master_thread) {
// Initiate bulk tensor copy
// Grad
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(&in_grad_sh[0]), copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y,
TMAP_grad_in, chunk_offset_X, chunk_offset_Y, &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y,
&mbar[0]); &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y,
} shmem_buff_size, &mbar[0], is_master_thread);
// Act
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(&in_act_sh[0]),
TMAP_in_act, chunk_offset_X, chunk_offset_Y,
&mbar[0]);
// Gate
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(&in_gate_sh[0]),
TMAP_in_gate, chunk_offset_X, chunk_offset_Y,
&mbar[0]);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size);
} else { } else {
// Other threads just arrive copy_2d_to_sharedx2(&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y,
ptx::mbarrier_arrive(&mbar[0]); &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y,
shmem_buff_size, &mbar[0], is_master_thread);
} }
#pragma unroll #pragma unroll
for (int it = 0; it < ITERATIONS; ++it) { for (int stage = 0; stage < STAGES; ++stage) {
const int buff = it % BUFFERS_NUM; const int buff = stage % BUFFS_NUM;
const int next_it = it + 1; const int next_stage = stage + 1;
const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; const int stage_offset_Y = stage * BUFF_DIM_Y;
if (next_it < ITERATIONS) {
if (is_master_thread) { if (next_stage < STAGES) {
const int next_buff = next_it % BUFFERS_NUM; // Wait for TMA transfer to have finished reading shared memory.
const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; // I.e. the buffer is ready to be written to
const int chunk_it_offset_x = chunk_offset_X; ptx::cp_async_bulk_wait_group_read<1>();
// Initiate bulk tensor copy
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
// Grad copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
ptx::cp_async_bulk_tensor_2d_global_to_shared( global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
reinterpret_cast<uint64_t *>(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset],
chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); &tensor_map_input_gate, global_offset_X, global_offset_Y,
} shmem_buff_size, &mbar[next_stage], is_master_thread);
// Act
ptx::cp_async_bulk_tensor_2d_global_to_shared(
reinterpret_cast<uint64_t *>(&in_act_sh[next_buff * buff_elems]), TMAP_in_act,
chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]);
// Gate
ptx::cp_async_bulk_tensor_2d_global_to_shared(
reinterpret_cast<uint64_t *>(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate,
chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size);
} else { } else {
// Other threads just arrive copy_2d_to_sharedx2(&in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X,
ptx::mbarrier_arrive(&mbar[next_it]); global_offset_Y, &in_gate_sh[next_buff_offset], &tensor_map_input_gate,
global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
} }
} }
ptx::fence_proxy_async_shared_cta(); ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived // Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[it], parity); ptx::mbarrier_wait_parity(&mbar[stage], parity);
IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; if constexpr (COLWISE_SCALING) {
IType *in_act_sh_curr = in_act_sh + buff * buff_elems; const int shmem_offset_base_colwise =
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise;
OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; float thread_amax_act = 0.0f;
OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; float thread_amax_gate = 0.0f;
OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; float after_act_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE];
OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; float after_gate_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE];
// Assuming one iteration covers exactly 32 rows
const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it;
const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y;
float after_dact_reg[BUFFER_STAGES_NUM];
float after_dgate_reg[BUFFER_STAGES_NUM];
float thread_Y_mx_block_amax = 0.0f;
float thread_Y_mx_block_amax_gate = 0.0f;
// 1. Read/Compute elements. Find MXFP8-block AMAX
#pragma unroll #pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; const int shmem_offset_colwise =
const int shmem_offset_y = thread_offset_Y + stage_offset_Y; shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
const int shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
const size_t row = row_base + shmem_offset_y; float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]);
const bool row_out_of_bounds = (row >= rows); float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); float after_act_elt;
float after_gate_elt;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]); float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
const float x = act_elt; const float x = act_elt;
float act_x; float act_x;
float dact_x; float dact_x;
...@@ -496,224 +491,393 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -496,224 +491,393 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
act_x = ActOP(x, {}); act_x = ActOP(x, {});
dact_x = DActOP(x, {}); dact_x = DActOP(x, {});
} }
after_dact_reg[stage] = dact_x * grad_elt * gate_elt; after_act_elt = dact_x * grad_elt * gate_elt;
after_dgate_reg[stage] = act_x * grad_elt; after_gate_elt = act_x * grad_elt;
} else { } else {
after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; after_act_elt = ActOP(act_elt, {}) * gate_elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
if constexpr (IS_DGATED) {
after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
}
}
after_act_colwise[i] = after_act_elt;
if constexpr (IS_DGATED) {
after_gate_colwise[i] = after_gate_elt;
} }
if constexpr (USE_ROWWISE_SCALING) { // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(after_act_elt);
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
// dgate cached_gate_sh[shmem_offset_colwise] = static_cast<IType>(after_gate_elt);
float amax = fabsf(after_dgate_reg[stage]); }
const float mx_block_X_amax = warp_reduce_max_broadcast(amax); }
const e8m0_t biased_exponent_X =
float_to_e8m0(mx_block_X_amax * Quantized_Limits<OType>::max_norm_rcp); const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
out_gate_rowwise_sh_curr[shmem_idx] = if (!out_of_bounds) {
static_cast<OType>(scale_reciprocal_X * after_dgate_reg[stage]); thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
// Only single thread writes the computed scaling factor
if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) {
const int global_scales_offset_Y =
iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y;
const int global_scales_offset_X =
scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent_X;
}
}
float amax = fabsf(after_dact_reg[stage]);
const float mx_block_X_amax = warp_reduce_max_broadcast(amax);
const e8m0_t biased_exponent_X =
float_to_e8m0(mx_block_X_amax * Quantized_Limits<OType>::max_norm_rcp);
const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X);
out_act_rowwise_sh_curr[shmem_idx] =
static_cast<OType>(scale_reciprocal_X * after_dact_reg[stage]);
// Only single thread writes the computed scaling factor
if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) {
const int global_scales_offset_Y =
iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y;
const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent_X;
}
}
if constexpr (USE_COLWISE_SCALING) {
__builtin_assume(thread_Y_mx_block_amax >= 0);
__builtin_assume(thread_Y_mx_block_amax_gate >= 0);
thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage]));
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
thread_Y_mx_block_amax_gate = thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage]));
} }
} }
} }
if constexpr (USE_COLWISE_SCALING) { if constexpr (ONLY_COLWISE_SCALING) {
const bool row_out_of_bounds = (row_base >= rows); // Threads, whose id along Y-dim is 0, don't need to store to shared memory,
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); // as they manage the columwise reduction of the amax
if (tid_Y_colwise > 0) {
subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act;
}
__syncthreads();
if (tid_Y_colwise == 0) {
#pragma unroll
for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) {
const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise];
__builtin_assume(thread_amax_act >= 0);
__builtin_assume(other_thread_amax >= 0);
thread_amax_act = fmaxf(thread_amax_act, other_thread_amax);
}
subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act;
}
__syncthreads();
// All threads read the reduced amax (ACT)
thread_amax_act = subamax_colwise_buff[0][tid_X_colwise];
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
// Colwise max reduction of the amax element // Make sure the previous read of the ACT values has been completed,
if (tid_Y > 0) { // so the data are not rewritten
stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; __syncthreads();
if (tid_Y_colwise > 0) {
subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate;
} }
__syncthreads(); __syncthreads();
if (tid_Y == 0) { if (tid_Y_colwise == 0) {
#pragma unroll #pragma unroll
for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) {
thread_Y_mx_block_amax_gate = const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise];
fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); __builtin_assume(thread_amax_gate >= 0);
__builtin_assume(other_thread_amax >= 0);
thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax);
} }
stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate;
} }
__syncthreads(); __syncthreads();
const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax // All threads read the reduced amax (GATE)
thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise];
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
// For the scaling along both dimensions, the thread amax is already computed in ROWWISE section const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
if constexpr (!USE_ROWWISE_SCALING) { const int global_scales_offset_X = scales_offset_X_colwise;
__builtin_assume(mx_block_Y_amax >= 0); const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx] = biased_exponent_act;
} }
const e8m0_t biased_exponent = float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act);
float_to_e8m0(mx_block_Y_amax * Quantized_Limits<OType>::max_norm_rcp); float block_scale_inverse_gate;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
// Only single thread writes the computed scaling factor if constexpr (IS_DGATED) {
// Also assuming one iteration covers exactly 32 rows const e8m0_t biased_exponent_gate =
if ((tid_Y == 0) && !out_of_bounds) { ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; // const int scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
const int scale_idx = if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; scales_colwise[scale_idx_gate] = biased_exponent_gate;
scales_colwise[scale_idx] = biased_exponent; }
block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate);
} }
// 3. Scale elements
#pragma unroll #pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; const int shmem_offset_elt =
const int shmem_offset_y = thread_offset_Y + stage_offset_Y; shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
const int shmem_offset_x = thread_offset_X; if constexpr (IS_DGATED) {
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; OType2 out_pair;
ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]};
out_gate_colwise_sh_curr[shmem_idx] = const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act,
static_cast<OType>(scale_reciprocal * after_dgate_reg[stage]); block_scale_inverse_gate};
ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair);
out_act_colwise_sh[shmem_offset_elt] = out_pair.x;
out_gate_colwise_sh[shmem_offset_elt] = out_pair.y;
} else {
const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i];
out_act_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out_act);
} }
} }
// Colwise max reduction of the amax element
if (tid_Y > 0) {
stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax;
} }
if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f;
Vec<IType, PACK_SIZE> in_cached_act[WAVES];
Vec<IType, PACK_SIZE> in_cached_gate[WAVES];
float after_act_rowwise[SCALE_DIM_X];
float after_gate_rowwise[SCALE_DIM_X];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads(); __syncthreads();
if (tid_Y == 0) { IType2 thread_amax_2x_act = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll #pragma unroll
for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { for (int w = 0; w < WAVES; ++w) {
thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]);
} }
stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
// only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e]));
if constexpr (IS_DGATED) {
thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e]));
} }
__syncthreads(); }
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e],
in_cached_act[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act);
if constexpr (IS_DGATED) {
const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e],
in_cached_gate[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate);
}
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax_act = static_cast<float>(
__hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y)));
if constexpr (IS_DGATED) {
thread_amax_gate = static_cast<float>(
__hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y)));
}
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax Vec<IType, PACK_SIZE> in_grad;
Vec<IType, PACK_SIZE> in_act;
Vec<IType, PACK_SIZE> in_gate;
// For the scaling along both dimensions, the thread amax is already computed in ROWWISE section in_act.load_from(&in_act_sh[shmem_offset_rowwise]);
if constexpr (!USE_ROWWISE_SCALING) { in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]);
__builtin_assume(mx_block_Y_amax >= 0); if constexpr (IS_DGATED) {
in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]);
} }
const e8m0_t biased_exponent = #pragma unroll
float_to_e8m0(mx_block_Y_amax * Quantized_Limits<OType>::max_norm_rcp); for (int e = 0; e < PACK_SIZE; ++e) {
const float scale_reciprocal = exp2f_rcp(biased_exponent); const int j = w * PACK_SIZE + e;
float act_elt = static_cast<float>(in_act.data.elt[e]);
float gate_elt = static_cast<float>(in_gate.data.elt[e]);
float after_act_elt;
float after_gate_elt;
// Only single thread writes the computed scaling factor if constexpr (IS_DGATED) {
// Also assuming one iteration covers exactly 32 rows float grad_elt = static_cast<float>(in_grad.data.elt[e]);
if ((tid_Y == 0) && !out_of_bounds) { const float x = act_elt;
const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; float act_x;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; float dact_x;
const int scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
scales_colwise[scale_idx] = biased_exponent; const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = act_x * grad_elt;
after_act_rowwise[j] = after_act_elt;
after_gate_rowwise[j] = after_gate_elt;
} else {
after_act_elt = ActOP(act_elt, {}) * gate_elt;
after_act_rowwise[j] = after_act_elt;
} }
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
if constexpr (IS_DGATED) {
after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
}
}
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
if constexpr (IS_DGATED) {
thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
}
}
}
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx] = biased_exponent_act;
}
const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act);
const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act,
block_scale_inverse_act};
float block_scale_inverse_gate;
ptx::floatx2 block_scale_inverse_2x_gate;
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate);
block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate};
}
// 3. Scale elements
#pragma unroll #pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { for (int w = 0; w < WAVES; ++w) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; Vec<OType2, PACK_SIZE / 2> out_act;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y; Vec<OType2, PACK_SIZE / 2> out_gate;
const int shmem_offset_x = thread_offset_X; #pragma unroll
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; for (int e = 0; e < PACK_SIZE / 2; ++e) {
IType2 in_act;
OType2 &out_act_pair = reinterpret_cast<OType2 &>(out_act.data.elt[e]);
out_act_colwise_sh_curr[shmem_idx] = if constexpr (IS_CACHED_ACT_OP) {
static_cast<OType>(scale_reciprocal * after_dact_reg[stage]); in_act.x = in_cached_act[w].data.elt[2 * e];
in_act.y = in_cached_act[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in_act.x = after_act_rowwise[j];
in_act.y = after_act_rowwise[j + 1];
} }
} // endif USE_COLWISE_SCALING ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act);
// Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) if constexpr (IS_DGATED) {
IType2 in_gate;
OType2 &out_gate_pair = reinterpret_cast<OType2 &>(out_gate.data.elt[e]);
if constexpr (IS_CACHED_ACT_OP) {
in_gate.x = in_cached_gate[w].data.elt[2 * e];
in_gate.y = in_cached_gate[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in_gate.x = after_gate_rowwise[j];
in_gate.y = after_gate_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
}
}
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta(); ptx::fence_proxy_async_shared_cta();
__syncthreads(); __syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine. // After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int chunk_it_offset_x = chunk_offset_X; const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
// dGeLU if constexpr (ROWWISE_SCALING) {
if constexpr (USE_ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast<const uint64_t *>(&tensor_map_output_act_rowwise), global_offset_X,
reinterpret_cast<uint64_t *>(out_act_rowwise_sh_curr)); global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_rowwise_sh[buff_offset]));
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_rowwise), global_offset_X,
reinterpret_cast<uint64_t *>(out_gate_rowwise_sh_curr)); global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_rowwise_sh[buff_offset]));
} }
} }
if constexpr (COLWISE_SCALING) {
// dGeLU
if constexpr (USE_COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast<const uint64_t *>(&tensor_map_output_act_colwise), global_offset_X,
reinterpret_cast<uint64_t *>(out_act_colwise_sh_curr)); global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_colwise_sh[buff_offset]));
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_colwise), global_offset_X,
reinterpret_cast<uint64_t *>(out_gate_colwise_sh_curr)); global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_colwise_sh[buff_offset]));
} }
} }
// Create a "bulk async-group" out of the previous bulk copy operation. // Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group(); ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
} }
} }
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
// Destroy the barriers. This invalidates the memory region of the barrier. parity ^= 1;
// If further computations were to take place in the kernel, this allows the destroy_barriers<STAGES>(mbar, is_master_thread);
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_invalid(&mbar[it]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
} // namespace mxfp8_kernel
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
...@@ -771,17 +935,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ...@@ -771,17 +935,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in = const size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out = const size_t buff_size_aligned_out =
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in; const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out; const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out;
// const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
(out_act_mem + out_gate_mem); // + mbar_mem;
cudaFuncSetAttribute( cudaFuncSetAttribute(
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>, cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>,
...@@ -809,16 +972,34 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -809,16 +972,34 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
} }
// TODO: Make more general ScalingType scaling_type;
const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) {
const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; scaling_type = ScalingType::ROWWISE;
} else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) {
scaling_type = ScalingType::COLWISE;
} else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) {
scaling_type = ScalingType::BIDIMENSIONAL;
}
const size_t rows = gated_input.flat_first_dim(); const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2; const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y;
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X;
constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM;
const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X);
constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE;
constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE;
const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE)
? THREADS_PER_CHUNK_COLWISE
: THREADS_PER_CHUNK_NON_COLWISE;
const dim3 grid(blocks_X, blocks_Y);
const dim3 block_size(THREADS_PER_CHUNK);
size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1;
size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1;
...@@ -828,14 +1009,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -828,14 +1009,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
e8m0_t *const scales_colwise_ptr = e8m0_t *const scales_colwise_ptr =
USE_COLWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr; USE_COLWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const dim3 block_dim(THREADS_PER_CHUNK); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
const dim3 grid_dim(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_Y_colwise, SCALE_DIM_Y,
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_X_rowwise, SCALE_DIM_X,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gated_input.dtype(), IType, gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType, output->dtype(), OType,
...@@ -848,42 +1022,45 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -848,42 +1022,45 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
alignas(64) CUtensorMap tensor_map_output_act_colwise{}; alignas(64) CUtensorMap tensor_map_output_act_colwise{};
alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; alignas(64) CUtensorMap tensor_map_output_gate_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); cols, 0, input_type_bit_size);
} }
const uint32_t tensor_stride_elems = output_cols; const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
typeToNumBits(gated_input.dtype())); create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_X, cols * 2, cols, input_type_bit_size);
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols,
typeToNumBits(gated_input.dtype()));
if (USE_ROWWISE_SCALING) { if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
typeToNumBits(output->dtype())); output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype())); output_type_bit_size);
} }
if (USE_COLWISE_SCALING) { if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
0, typeToNumBits(output->dtype())); output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
cols, typeToNumBits(output->dtype())); output_type_bit_size);
} }
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
const size_t buff_size_aligned_in = const size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out = const size_t buff_size_aligned_out =
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in; const size_t in_act_mem = buff_size_aligned_in;
...@@ -891,30 +1068,62 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -891,30 +1068,62 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out; const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out; const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
size_t out_mem = out_act_mem + out_gate_mem; size_t out_mem = out_act_mem + out_gate_mem;
if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
// const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
// const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem;
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, true, false,
THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, false, THREADS_PER_CHUNK_NON_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::COLWISE:
cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, false, true,
THREADS_PER_CHUNK_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
false, true, THREADS_PER_CHUNK_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute( cudaFuncSetAttribute(
cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
SCALE_DIM_Y, SCALE_DIM_X>, OType, true, true,
THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
SCALE_DIM_Y, SCALE_DIM_X> true, true, THREADS_PER_CHUNK_NON_COLWISE>
<<<grid_dim, block_dim, shmem_size, stream>>>( <<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);); // NOLINT(*) scale_stride_colwise);
); // NOLINT(*) break;
); // NOLINT(*) }); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
......
...@@ -28,36 +28,25 @@ ...@@ -28,36 +28,25 @@
namespace transformer_engine { namespace transformer_engine {
constexpr size_t MXFP8_CHUNK_DIM_Y = 64; namespace mxfp8_kernel {
constexpr size_t MXFP8_CHUNK_DIM_X = 64;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; constexpr size_t SCALE_DIM_X = 32;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X;
constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; constexpr size_t BUFFS_NUM = 2;
constexpr size_t MXFP8_BUFFERS_NUM = 2; constexpr size_t PACK_SIZE = 4;
constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM);
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t ELEMS_PER_THREAD = 16; constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported
constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64
constexpr size_t THREADS_PER_CHUNK_X_ROWWISE =
MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16
constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE =
MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4
constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64
constexpr size_t MXFP8_BUFF_STAGES_NUM =
MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16
constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32
static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM);
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, size_t SCALE_DIM_Y, float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
size_t SCALE_DIM_X> bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_act_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise, const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
...@@ -67,201 +56,198 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) ...@@ -67,201 +56,198 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) { const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT;
if (noop != nullptr && noop[0] == 1.0f) return; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS;
}
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1
constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y =
SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1
constexpr size_t SCALES_ROWWISE_PER_BLOCK_X =
SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1
constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1
constexpr size_t SCALES_COLWISE_PER_BLOCK_Y =
SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1
constexpr size_t SCALES_COLWISE_PER_BLOCK_X =
SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X;
const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y;
const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X;
const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y;
const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X;
const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE;
const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE;
// const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE;
const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE;
const int thread_offset_Y = tid_rowwise_Y;
const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD;
// const int thread_offset_X_colwise = tid_colwise_X;
const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y;
const int dbias_rowwise_block_offset_X =
blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise;
const int dbias_colwise_offset_Y = blockIdx.y;
const int dbias_colwise_block_offset_X =
blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X;
const int dbias_stride = cols;
Vec<float, ELEMS_PER_THREAD> partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; using IType2 = typename ptx::FPx2<IType>;
float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; using OType2 = typename ptx::FPx2<OType>;
if constexpr (IS_DBIAS) {
if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { if constexpr (NO_ACTIVATIONS) {
#pragma unroll if (noop != nullptr && noop[0] == 1.0f) {
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { return;
partial_dbias_rowwise[i].clear();
}
} else {
#pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
partial_dbias_colwise[i] = 0;
}
} }
} }
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
// The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned constexpr size_t BUFF_DIM_Y = THREADS_Y;
__shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
__shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
__shared__ alignas(128) static_assert(BUFF_DIM_Y == 32);
OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
__shared__ alignas(128)
OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
const bool is_master_thread = (threadIdx.x == 0); constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
float block_amax = 0; const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
// Initialize shared memory barrier with the number of threads participating in the barrier. const int tid_Y_rowwise = threadIdx.x / THREADS_X;
#pragma nv_diag_suppress static_var_with_dynamic_init const int tid_X_rowwise = threadIdx.x % THREADS_X;
__shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
initialize_barriers<MXFP8_ITERATIONS, MXFP8_THREADS_PER_CHUNK>(mbar, is_master_thread); const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise;
int parity = 0; const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
#pragma unroll const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X;
const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X;
const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const int scales_rowwise_chunk_offset_Y = // helps resolving bank conflicts in shmem
scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int scales_rowwise_chunk_offset_X = const int bank_group = thread_lane / THREADS_PER_BANK;
scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X;
const int scales_colwise_chunk_offset_Y =
scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y;
const int scales_colwise_chunk_offset_X =
scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X;
#pragma unroll constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; constexpr size_t buff_size_aligned_in =
const int chunk_stage_offset_X = chunk_offset_X; DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
if constexpr (IS_DACT) { constexpr size_t buff_size_aligned_out =
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, constexpr size_t elt_input_mem = buff_size_aligned_in;
&mbar[prefetch_buff], is_master_thread); constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
} else { constexpr size_t in_mem = elt_input_mem + act_input_mem;
copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0);
is_master_thread);
}
}
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
float partial_dbias_colwise = 0.0f;
float thread_dbias_rowwise[SCALE_DIM_X];
if constexpr (IS_DBIAS) {
#pragma unroll #pragma unroll
for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { for (int j = 0; j < SCALE_DIM_X; ++j) {
const int buff = iter % MXFP8_BUFFERS_NUM; thread_dbias_rowwise[j] = 0.0f;
const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
if (next_iter < MXFP8_ITERATIONS) {
const int next_buff = next_iter % MXFP8_BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size,
&mbar[next_iter], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread);
} }
} }
ptx::fence_proxy_async_shared_cta(); float block_amax = 0.0f;
// Wait for the data to have arrived // Initialize shared memory barrier with the number of threads participating in the barrier.
ptx::mbarrier_wait_parity(&mbar[iter], parity); #pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
if constexpr (USE_ROWWISE_SCALING) { int parity = 0;
Vec<IType, ELEMS_PER_THREAD> in;
Vec<IType, ELEMS_PER_THREAD> act_in;
Vec<OType, ELEMS_PER_THREAD> out_c;
const int iteration_scale_rowwise_offset_Y = if constexpr (IS_DACT) {
scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0],
&tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
}
#pragma unroll #pragma unroll
for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { for (int stage = 0; stage < STAGES; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; const int buff = stage % BUFFS_NUM;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const int next_stage = stage + 1;
const int shmem_offset_x = thread_offset_X_rowwise; const int stage_offset_Y = stage * BUFF_DIM_Y;
const size_t row = row_base + shmem_offset_y; if (next_stage < STAGES) {
const bool row_out_of_bounds = (row >= rows); // Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
} }
float thread_amax = 0; ptx::fence_proxy_async_shared_cta();
float in_compute[ELEMS_PER_THREAD];
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity);
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
thread_amax = 0.0f;
float in_compute_colwise[BUFF_DIM_Y];
IType in_colwise_IType[BUFF_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
}
thread_amax = static_cast<float>(thread_amax_f16);
} else {
#pragma unroll #pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) { for (int i = 0; i < BUFF_DIM_Y; ++i) {
const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
float elt = static_cast<float>(in.data.elt[j]); float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (IS_ACT) { if constexpr (IS_ACT) {
elt = OP(elt, {}); elt = OP(elt, {});
} }
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in.data.elt[j]); float act_in_elt = static_cast<float>(act_in_sh[shmem_offset_colwise]);
elt *= OP(act_in_elt, {}); elt *= OP(act_in_elt, {});
} }
if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { if constexpr (IS_DBIAS) {
if (!out_of_bounds) { partial_dbias_colwise += elt;
partial_dbias_rowwise[chunk_X].data.elt[j] += elt; }
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
} }
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
} }
in_compute[j] = elt;
if constexpr (IS_ACT || IS_DACT) { if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) { if (!out_of_bounds) {
thread_amax = fmaxf(thread_amax, fabsf(elt)); thread_amax = fmaxf(thread_amax, fabsf(elt));
} }
...@@ -269,92 +255,196 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) ...@@ -269,92 +255,196 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
// If no activation, elt is 0 so we can safely do this // If no activation, elt is 0 so we can safely do this
thread_amax = fmaxf(thread_amax, fabsf(elt)); thread_amax = fmaxf(thread_amax, fabsf(elt));
} }
in_compute_colwise[i] = elt;
}
} }
__builtin_assume(block_amax >= 0); // 2. Compute E8M0 scaling factor
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
const float subwarp_amax = subwarp_reduce_max_broadcast<SUBWARP_WIDTH>(thread_amax);
const e8m0_t biased_exponent = const e8m0_t biased_exponent =
float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
// Only single thread writes the computed scaling factor const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { const int global_scales_offset_X = scales_offset_X_colwise;
const int global_scales_offset_Y = const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; scales_colwise[scale_idx] = biased_exponent;
const int global_scales_offset_X =
scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE;
const int scale_idx =
global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = exp2f_rcp(biased_exponent); const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll #pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) { for (int i = 0; i < SCALE_DIM_Y; ++i) {
out_c.data.elt[j] = static_cast<OType>(in_compute[j] * block_scale_inverse); float in;
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
} }
out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
} }
} }
if constexpr (USE_COLWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
float in_compute[SCALE_DIM_Y]; thread_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
float amax = 0; // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll #pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) { for (int w = 0; w < WAVES; ++w) {
const size_t row = row_base + i; const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const bool row_out_of_bounds = (row >= rows); const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
// only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
float elt = static_cast<float>(in_sh[buff][i][tid_colwise_X]); Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[shmem_offset_rowwise]);
}
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (IS_ACT) { if constexpr (IS_ACT) {
elt = OP(elt, {}); elt = OP(elt, {});
} }
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[buff][i][tid_colwise_X]); float act_in_elt = static_cast<float>(act_in.data.elt[e]);
elt *= OP(act_in_elt, {}); elt *= OP(act_in_elt, {});
} }
if constexpr (IS_DBIAS) {
if (!out_of_bounds) { // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again
partial_dbias_colwise[chunk_X] += elt; if constexpr (IS_DBIAS && (!COLWISE_SCALING)) {
thread_dbias_rowwise[j] += elt;
} }
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
} }
in_compute[i] = elt; if constexpr (COMPUTE_ACTIVATIONS) {
if constexpr (IS_ACT || IS_DACT) { const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) { if (!out_of_bounds) {
amax = fmaxf(amax, fabsf(elt)); thread_amax = fmaxf(thread_amax, fabsf(elt));
} }
} else { } else {
// If no activation, elt is 0 so we can safely do this // If no activation, elt is 0 so we can safely do this
amax = fmaxf(amax, fabsf(elt)); thread_amax = fmaxf(thread_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
} }
} }
__builtin_assume(block_amax >= 0); // 2. Compute E8M0 scaling factor
__builtin_assume(amax >= 0); const e8m0_t biased_exponent =
block_amax = fmaxf(block_amax, amax); ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp); const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent;
const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
const int scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = exp2f_rcp(biased_exponent); // 3. Scale elements
#pragma unroll #pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) { for (int w = 0; w < WAVES; ++w) {
out_colwise_sh[buff][i][tid_colwise_X] = Vec<OType2, PACK_SIZE / 2> out;
static_cast<OType>(in_compute[i] * block_scale_inverse); #pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
IType2 in;
OType2 &out_pair = reinterpret_cast<OType2 &>(out.data.elt[e]);
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = in_IType[w].data.elt[e];
} else if constexpr (IS_CACHED_ACT_OP) {
in.x = in_cached[w].data.elt[2 * e];
in.y = in_cached[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in.x = in_compute_rowwise[j];
in.y = in_compute_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
} }
} }
__builtin_assume(block_amax >= 0);
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
// Wait for shared memory writes to be visible to TMA engine. // Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta(); ptx::fence_proxy_async_shared_cta();
__syncthreads(); __syncthreads();
...@@ -362,103 +452,87 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) ...@@ -362,103 +452,87 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int chunk_it_offset_x = chunk_offset_X; const int global_offset_X = block_offset_X;
if constexpr (USE_ROWWISE_SCALING) { const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), chunk_it_offset_x, reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff])); global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff_offset]));
} }
if constexpr (USE_COLWISE_SCALING) { if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), chunk_it_offset_x, reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff])); global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff_offset]));
} }
// Create a "bulk async-group" out of the previous bulk copy operation. // Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group(); ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<MXFP8_PREFETCH_BUFFERS_NUM>();
} }
} }
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
parity ^= 1; parity ^= 1;
}
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { float thread_partial_dbias = 0.0f;
constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; if constexpr (COLWISE_SCALING) {
constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; thread_partial_dbias = partial_dbias_colwise;
constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; } else {
__shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
// HEIGHT = THREADS_Y
if (tid_rowwise_Y > 0) { // WIDTH = THREADS_X * (SCALE_DIM_X + 1)
#pragma unroll // Added extra 1-element padding per thread_X to reduce bank conflicts
for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
partial_dbias_rowwise[c].store_to(
&shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]);
}
}
__syncthreads();
if (tid_rowwise_Y == 0) {
#pragma unroll
for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) {
Vec<float, ELEMS_PER_THREAD> other_row_dbias;
const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X;
const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X;
const int left_bound = dbias_rowwise_offset_X; constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1;
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll #pragma unroll
for (int i = 0; i < Y; ++i) { for (int w = 0; w < WAVES; ++w) {
other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll #pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) { for (int e = 0; e < PACK_SIZE; ++e) {
partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; const int j = w * PACK_SIZE + e;
} const int shmem_elt_idx = swizzled_group_offset + e;
} partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
// Vectorized store when all elements are inside the boundaries
if (right_bound < cols) {
partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]);
} else if (left_bound < cols && right_bound >= cols) {
// Element-by-element store when some elements cross the boundaries
const int in_bound_elts_count = cols - left_bound;
partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0,
in_bound_elts_count);
} }
} }
} __syncthreads();
} else {
#pragma unroll #pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { for (int i = 0; i < THREADS_Y; ++i) {
const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; // Add extra element offset per MXFP8 scaling block [1x32]
const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; const int scaling_block = threadIdx.x / SCALE_DIM_X;
const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); thread_partial_dbias +=
if (!col_out_of_bounds) { partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
dbias_workspace[dbias_offset] = partial_dbias_colwise[i];
} }
} }
const int dbias_stride = cols;
const int dbias_offset_Y = blockIdx.y;
const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
} }
} }
if (amax_ptr != nullptr) { if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block // Reduce the amax over the block
block_amax = reduce_max<MXFP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id); block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
} }
if (is_master_thread && amax_ptr != nullptr) { if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, block_amax); atomicMaxFloat(amax_ptr, block_amax);
} }
destroy_barriers<MXFP8_ITERATIONS>(mbar, is_master_thread); destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
} // namespace mxfp8_kernel
constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128; constexpr size_t FP8_CHUNK_DIM_X = 128;
...@@ -507,9 +581,12 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -507,9 +581,12 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT)
__shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT)
IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
...@@ -678,8 +755,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -678,8 +755,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
...@@ -921,6 +998,7 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, ...@@ -921,6 +998,7 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
void mxfp8_quantize(const Tensor &input, const Tensor *act_input, void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const Tensor *noop, // TODO (ksivamani) const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
using namespace mxfp8_kernel;
bool use_rowwise_scaling = output->has_data(); bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data(); bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -936,16 +1014,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -936,16 +1014,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
} }
CheckNoopTensor(*noop, "cast_noop"); CheckNoopTensor(*noop, "cast_noop");
// TODO: Make more general
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1;
const size_t rows = input.flat_first_dim(); const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim(); const size_t cols = input.flat_last_dim();
const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT);
const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y);
const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1;
const size_t scale_stride_colwise = const size_t scale_stride_colwise =
...@@ -958,6 +1044,15 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -958,6 +1044,15 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t dbias_rows = blocks_Y; const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols; const size_t dbias_cols = cols;
ScalingType scaling_type;
if (use_rowwise_scaling && (!use_colwise_scaling)) {
scaling_type = ScalingType::ROWWISE;
} else if ((!use_rowwise_scaling) && use_colwise_scaling) {
scaling_type = ScalingType::COLWISE;
} else if (use_rowwise_scaling && use_colwise_scaling) {
scaling_type = ScalingType::BIDIMENSIONAL;
}
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias."); NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
...@@ -972,15 +1067,9 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -972,15 +1067,9 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr; float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr); float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const dim3 block(MXFP8_THREADS_PER_CHUNK); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
const dim3 grid(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_Y_colwise, SCALE_DIM_Y,
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_X_rowwise, SCALE_DIM_X,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType, input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType, output->dtype(), OType,
...@@ -990,40 +1079,95 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -990,40 +1079,95 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
alignas(64) CUtensorMap tensor_map_output_rowwise{}; alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{}; alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, BUFF_DIM_X, cols, 0, input_type_bit_size);
typeToNumBits(input.dtype()));
} }
if (use_rowwise_scaling) { if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, BUFF_DIM_X, cols, 0, output_type_bit_size);
typeToNumBits(output->dtype()));
} }
if (use_colwise_scaling) { if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size);
typeToNumBits(output->dtype())); }
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems;
SCALE_DIM_Y, SCALE_DIM_X><<<grid, block, 0, stream>>>( constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
reinterpret_cast<const float *>(noop->data.dptr), workspace_ptr, amax_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
rows, cols, scale_stride_rowwise, scale_stride_colwise); scale_stride_colwise);
break;
case ScalingType::COLWISE:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*) }); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
} }
namespace detail { namespace detail {
...@@ -1117,8 +1261,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons ...@@ -1117,8 +1261,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
case NVTE_DELAYED_TENSOR_SCALING: { case NVTE_DELAYED_TENSOR_SCALING: {
if (!IS_DBIAS && !IS_DACT) { if (!IS_DBIAS && !IS_DACT) {
if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_gmem_alignment) && is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_gmem_alignment)) { is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8 // Aligned AND FP8
cast_fp8_1D<IS_ACT, ParamOP, OP>(input, output, stream); cast_fp8_1D<IS_ACT, ParamOP, OP>(input, output, stream);
} else { } else {
...@@ -1127,9 +1271,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons ...@@ -1127,9 +1271,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
} }
} else if (!IS_DBIAS && IS_DACT) { } else if (!IS_DBIAS && IS_DACT) {
if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_gmem_alignment) && is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_gmem_alignment) && is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8 (+dAct) // Aligned AND FP8 (+dAct)
cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace, cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream); stream);
......
...@@ -84,8 +84,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -84,8 +84,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// const int thread_offset_X_colwise = tid_colwise_X; // const int thread_offset_X_colwise = tid_colwise_X;
// The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned
__shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
__shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size; constexpr int transaction_size = shmem_buff_size;
...@@ -166,7 +166,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -166,7 +166,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X;
const e8m0_t biased_exponent = scales_ptr[scale_idx]; const e8m0_t biased_exponent = scales_ptr[scale_idx];
const float block_scale = exp2f(static_cast<float>(biased_exponent) - FP32_EXPONENT_BIAS); const float block_scale = ptx::exp2f(biased_exponent);
if constexpr (USE_ROWWISE_SCALING) { if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in; Vec<IType, ELEMS_PER_THREAD> in;
......
...@@ -104,6 +104,53 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 ...@@ -104,6 +104,53 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Number of mantissa bits in an IEEE-754 binary32 (fp32) value.
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// IEEE-754 binary32 exponent bias.
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
// Computes 1 / 2^(biased_exp - 127), i.e. the reciprocal of the scale encoded
// by an E8M0 biased exponent, by assembling the FP32 bit pattern of
// 2^(127 - (biased_exp - 127)) directly. A biased exponent of 0 maps to 1.0f.
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
  if (biased_exp == 0) {
    return 1.0f;
  }
  // (254 - biased_exp) is the biased FP32 exponent field of the reciprocal;
  // sign and mantissa bits stay zero.
  const int reciprocal_bits = (254 - biased_exp) << FP32_MANTISSA_BITS;
  return __int_as_float(reciprocal_bits);
}
// Decodes an E8M0 biased exponent into 2^(biased_exp - 127) by placing it
// directly into the exponent field of an FP32 bit pattern (sign and mantissa
// are zero).
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
  const int value_bits = static_cast<int>(biased_exp) << FP32_MANTISSA_BITS;
  return __int_as_float(value_bits);
}
// Converts an FP32 value to an E8M0 (8-bit biased exponent, no mantissa)
// scale factor, rounding the exponent up (toward +inf) and saturating to the
// largest finite encoding (0xFE). NaN maps to 0xFF, zero to 0x00.
// NOTE(review): the scalar fallback extracts the exponent with a plain
// right-shift, so a negative input would pick up the sign bit — callers
// appear to pass non-negative amax values; confirm.
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
     (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
  // Hardware path: cvt.rp.satfinite.ue8m0x2.f32 converts the fp32 pair
  // {0.0, val} to two packed ue8m0 bytes with round-up and
  // saturate-to-finite; the byte holding the converted `val` is extracted
  // below.
  uint16_t out;
  asm volatile(
      "{\n"
      "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
      "}"
      : "=h"(out)
      : "f"(val));
  return *reinterpret_cast<e8m0_t *>(&out);
#else
  // TODO: nan/inf needs to be set for any value
  // of nan/inf in input not just amax.
  if (isnan(val)) {
    return 0xFF;  // E8M0 NaN encoding
  }
  if (isinf(val)) {
    return 0xFE;  // saturate to the largest finite exponent
  }
  if (val == 0.0f) {
    return 0x00;
  }
  // Reinterpret the float as raw bits to pull out its exponent and mantissa.
  uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
  e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
  uint32_t mantissa = val_u32 & 0x7FFFFF;
  // Round up exponent and deal with satfinite: bump the exponent when any
  // mantissa bits remain, except when already at the largest finite exponent
  // (0xFE) or when the value is a subnormal small enough to keep the minimum
  // exponent.
  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
#endif
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
...@@ -169,6 +216,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { ...@@ -169,6 +216,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
asm volatile("fence.proxy.async.shared::cta;"); asm volatile("fence.proxy.async.shared::cta;");
} }
// A pair of values of the same type T, aligned to 2*sizeof(T) so the pair can
// be moved with a single vectorized load/store. `x` is stored first (the
// low-order element when the pair is reinterpreted as one wide register).
template <typename T>
struct alignas(2 * sizeof(T)) FPx2 {
  T x;
  T y;
};

// Convenience aliases for the packed-pair types used by the cast kernels.
using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;

// The pairs must stay tightly packed: the inline PTX below reinterprets them
// as single 16/32/64-bit registers.
static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);
// SIMD-like "fused" cast + multiplication (x2): multiplies two fp32 lanes by
// the per-lane scale with a single mul.f32x2, then converts both products to
// a packed fp8e4m3 pair (round-to-nearest-even, saturate-to-finite).
// NOTE(review): the {val2,val1} unpack order compensates for the operand
// packing of cvt ...x2.f32 (first source lands in the upper byte), so
// out.x/out.y keep the same lane order as in.x/in.y — confirm against the
// PTX ISA before modifying.
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
                                           const floatx2 &scale) {
  asm volatile(
      "{\n"
      ".reg.b64 val_pair; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      "mul.f32x2 val_pair, %1, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair; \n\t"
      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "l"(reinterpret_cast<const uint64_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
}
// Fused scale + cast of two fp32 lanes to a packed fp8e5m2 pair: one
// mul.f32x2 applies the per-lane scale, then cvt.rn.satfinite.e5m2x2.f32
// converts both products (round-to-nearest-even, saturate-to-finite).
// NOTE(review): the {val2,val1} unpack order compensates for cvt's operand
// packing so out.x/out.y preserve the lane order of in.x/in.y.
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
                                           const floatx2 &scale) {
  asm volatile(
      "{\n"
      ".reg.b64 val_pair; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      "mul.f32x2 val_pair, %1, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair; \n\t"
      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "l"(reinterpret_cast<const uint64_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
}
// Fused scale + cast of two bf16 lanes to a packed fp8e4m3 pair: each bf16
// lane is widened to fp32 (cvt.f32.bf16), both lanes are scaled with a single
// mul.f32x2, and the products are converted to packed e4m3
// (round-to-nearest-even, saturate-to-finite).
// NOTE(review): the {val2,val1} unpack order compensates for cvt's operand
// packing so out.x/out.y preserve the lane order of in.x/in.y.
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) {
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_bf16; \n\t"
      ".reg.b16 val2_bf16; \n\t"
      "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
      "cvt.f32.bf16 val1, val1_bf16; \n\t"
      "cvt.f32.bf16 val2, val2_bf16; \n\t"
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
}
// Fused scale + cast of two bf16 lanes to a packed fp8e5m2 pair: each bf16
// lane is widened to fp32 (cvt.f32.bf16), both lanes are scaled with a single
// mul.f32x2, and the products are converted to packed e5m2
// (round-to-nearest-even, saturate-to-finite).
// NOTE(review): the {val2,val1} unpack order compensates for cvt's operand
// packing so out.x/out.y preserve the lane order of in.x/in.y.
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) {
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_bf16; \n\t"
      ".reg.b16 val2_bf16; \n\t"
      "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
      "cvt.f32.bf16 val1, val1_bf16; \n\t"
      "cvt.f32.bf16 val2, val2_bf16; \n\t"
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
}
// Fused scale + cast of two fp16 lanes to a packed fp8e4m3 pair: each fp16
// lane is widened to fp32 (cvt.f32.f16), both lanes are scaled with a single
// mul.f32x2, and the products are converted to packed e4m3
// (round-to-nearest-even, saturate-to-finite).
// NOTE(review): the {val2,val1} unpack order compensates for cvt's operand
// packing so out.x/out.y preserve the lane order of in.x/in.y.
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) {
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_fp16; \n\t"
      ".reg.b16 val2_fp16; \n\t"
      "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
      "cvt.f32.f16 val1, val1_fp16; \n\t"
      "cvt.f32.f16 val2, val2_fp16; \n\t"
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
}
// Fused scale + cast of two fp16 lanes to a packed fp8e5m2 pair: each fp16
// lane is widened to fp32 (cvt.f32.f16), both lanes are scaled with a single
// mul.f32x2, and the products are converted to packed e5m2
// (round-to-nearest-even, saturate-to-finite).
// NOTE(review): the {val2,val1} unpack order compensates for cvt's operand
// packing so out.x/out.y preserve the lane order of in.x/in.y.
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) {
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_fp16; \n\t"
      ".reg.b16 val2_fp16; \n\t"
      "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
      "cvt.f32.f16 val1, val1_fp16; \n\t"
      "cvt.f32.f16 val2, val2_fp16; \n\t"
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
}
// Packed two-lane maximum of absolute values via PTX max.xorsign.abs.bf16x2:
// each lane of dst receives max(|p1|, |p2|) with its sign bit set to
// sign(p1) XOR sign(p2).
// NOTE(review): for non-negative inputs this is a plain packed abs-max;
// confirm callers do not rely on the xorsign semantics for signed inputs.
__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
  asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
               : "=r"(reinterpret_cast<uint32_t &>(dst))
               : "r"(reinterpret_cast<const uint32_t &>(p1)),
                 "r"(reinterpret_cast<const uint32_t &>(p2)));
}
// Packed two-lane maximum of absolute values via PTX max.xorsign.abs.f16x2:
// each lane of dst receives max(|p1|, |p2|) with its sign bit set to
// sign(p1) XOR sign(p2).
// NOTE(review): for non-negative inputs this is a plain packed abs-max;
// confirm callers do not rely on the xorsign semantics for signed inputs.
__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
  asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
               : "=r"(reinterpret_cast<uint32_t &>(dst))
               : "r"(reinterpret_cast<const uint32_t &>(p1)),
                 "r"(reinterpret_cast<const uint32_t &>(p2)));
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx } // namespace ptx
......
...@@ -905,10 +905,7 @@ using fp8e4m3 = __nv_fp8_e4m3; ...@@ -905,10 +905,7 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
using e8m0_t = uint8_t; using e8m0_t = uint8_t;
constexpr uint32_t FP32_MANTISSA_BITS = 23; enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 };
template <typename T> template <typename T>
struct Numeric_Traits; struct Numeric_Traits;
...@@ -934,44 +931,6 @@ struct Quantized_Limits { ...@@ -934,44 +931,6 @@ struct Quantized_Limits {
static constexpr float emax_rcp = 1.0 / emax; static constexpr float emax_rcp = 1.0 / emax;
}; };
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
// TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax.
if (isnan(val)) {
return 0xFF;
}
if (isinf(val)) {
return 0xFE;
}
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out;
asm volatile(
"{\n"
"cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
"}"
: "=h"(out)
: "f"(val));
return *reinterpret_cast<e8m0_t *>(&out);
#else
if (val == 0.0f) {
return 0x00;
}
uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent;
}
return exponent;
#endif
}
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
}
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment