Unverified commit 73939472 authored by Oleg Goncharov, committed by GitHub

[Common] MXFP8 kernel for grouped tensors (#2586)



* Rebased to main
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed the year to 2026
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Added compilation guards
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added BWD pass
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added dbias and dact tests. Refactoring.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added grouped MXFP8 DACT and ACT API and tests
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed a typo
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixes per the review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* More fixes from the review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixes per the review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Relaxed requirement for last dim from mod128 to mod32
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Added alignment checks when tensor descriptors are modified
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
parent 71971e33
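For orientation, here is a minimal usage sketch of the grouped quantization API added by this commit, assembled from the new test below. The device pointers in_data_d, out_data_d, out_scales_d, the dtypes itype/otype, the shape values, and stream are placeholders and not part of the diff:

size_t shape[2] = {rows, cols};  // logical 2D shape of the whole group
NVTEShape logical_shape = nvte_make_shape(shape, 2);
NVTEGroupedTensor in  = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape);
NVTEGroupedTensor out = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape);
NVTEBasicTensor in_data   = {in_data_d,    static_cast<NVTEDType>(itype), logical_shape};
NVTEBasicTensor out_data  = {out_data_d,   static_cast<NVTEDType>(otype), logical_shape};
NVTEBasicTensor out_scale = {out_scales_d, NVTEDType::kNVTEFloat8E8M0,    scales_shape};
nvte_set_grouped_tensor_param(&in,  NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,     &in_data);
nvte_set_grouped_tensor_param(&out, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,     &out_data);
nvte_set_grouped_tensor_param(&out, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scale);
nvte_group_quantize(in, out, stream);  // row-wise MXFP8 quantization of the whole group
// Groups with varying member shapes additionally set kNVTEGroupedFirstDims,
// kNVTEGroupedLastDims and kNVTEGroupedTensorOffsets, as done in the test below.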
......
@@ -11,6 +11,7 @@ add_executable(test_operator
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
test_cast_mxfp8_grouped.cu
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum ProcessingMethod {
CAST_ONLY,
CAST_DBIAS,
CAST_DBIAS_DACT,
CAST_DACT,
CAST_ACT
};
enum ActivationKind {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};
template <typename InputType, typename OutputType>
void compute_ref(const ProcessingMethod processing_method,
float (*OP)(const float),
const bool rowwise,
const bool colwise,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* output_scales_rowwise,
fp8e8m0* output_scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise,
const bool is_single_tensor)
{
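// MXFP8 assigns one shared E8M0 scale to every 32 consecutive elements (along rows for
// row-wise scaling, along columns for column-wise scaling). The reference walks the data
// in 32x32 tiles so both scaling directions can reuse the same cached values.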
const size_t tile_size_Y = 32;
const size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer(tile_size_Y * tile_size_X);
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(i_min + tile_size_Y, rows);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(j_min + tile_size_X, cols);
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
thread_dbias[j] += elt;
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
elt = static_cast<float>(static_cast<InputType>(elt));
cache_buffer[cache_idx] = elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
}
}
if (rowwise) {
for (size_t i = i_min; i < i_max; ++i) {
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
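// The shared scale is the power of two mapping block_amax onto the FP8 maximum:
// float_to_e8m0 stores the biased exponent of block_amax / FP8_MAX, and
// exp2f_rcp(biased_exponent) is the reciprocal applied to the cached elements.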
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
if (colwise) {
for (size_t j = j_min; j < j_max; ++j) {
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
}
#pragma omp critical
{
for (size_t j = 0; j < cols; ++j) {
output_dbias_fp32[j] += thread_dbias[j];
}
}
}
if (is_single_tensor) {
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
}
}
template <typename T>
void compare_scaled_elts(const std::string &name,
const T* ref_data,
const T* test_data,
const size_t rows,
const size_t cols,
const bool rowwise,
const size_t tolerable_mismatches_limit = 0,
const double atol = 1e-5,
const double rtol = 1e-8) {
size_t mismatches_num = 0;
int first_mismatch_idx = -1;
for (size_t i = 0; i < rows * cols; ++i) {
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = false;
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
std::string direction = rowwise ? "rowwise" : "columnwise";
if (assertion) {
mismatches_num++;
if (first_mismatch_idx == -1) {
first_mismatch_idx = i;
}
}
if (mismatches_num > tolerable_mismatches_limit) {
const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);
GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl
<< "First mismatch at index " << first_mismatch_idx << ": "
<< first_mismatch_t << " vs " << first_mismatch_r;
}
}
}
/**
* Grouped MXFP8 quantization test driver.
* Runs the grouped GPU kernel and a CPU reference, then compares:
* 1) Scaled rows + row-wise scaling factors
* and/or
* 2) Scaled columns + column-wise scaling factors
* plus the output of the fused operation (dbias), when applicable.
*/
template <typename InputType, typename OutputType>
void performTest(const ProcessingMethod processing_method,
float (*OP)(const float),
const ShapeRepresentation shape_rep,
const size_t num_tensors,
const std::vector<size_t>& logical_shape_vec,
const std::vector<size_t>& first_dims_h,
const std::vector<size_t>& last_dims_h,
const std::vector<size_t>& offsets_h,
const bool rowwise,
const bool colwise) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = logical_shape_vec[0];
const size_t cols = logical_shape_vec[1];
size_t elts_num = 0;
size_t rowwise_sfs_num = 0;
size_t colwise_sfs_num = 0;
std::vector<size_t> rowwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_offset(num_tensors + 1, 0);
std::vector<size_t> colwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_offset(num_tensors + 1, 0);
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];
const size_t elts = M * K;
elts_num += elts;
const size_t unpadded_rowwise_blocks_Y = M;
const size_t unpadded_rowwise_blocks_X = divide_round_up(K, 32);
const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32);
const size_t unpadded_colwise_blocks_X = K;
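// Scale-factor tensors are padded to the alignment expected by the MXFP8 scale layout:
// row-wise scales to (multiple of 128) x (multiple of 4) blocks, column-wise scales to
// (multiple of 4) x (multiple of 128) blocks.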
rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128);
rowwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
colwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_Y, 4);
colwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
const size_t rowwise_sfs = rowwise_scales_first_dim[t] * rowwise_scales_last_dim[t];
const size_t colwise_sfs = colwise_scales_first_dim[t] * colwise_scales_last_dim[t];
rowwise_sfs_num += rowwise_sfs;
colwise_sfs_num += colwise_sfs;
rowwise_scales_offset[t+1] = rowwise_sfs_num;
colwise_scales_offset[t+1] = colwise_sfs_num;
}
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);
std::vector<size_t> scales_rowwise_shape = {rowwise_sfs_num};
std::vector<size_t> scales_colwise_shape = {colwise_sfs_num};
std::mt19937 gen;
std::uniform_real_distribution<> dis(-2.0, 1.0);
std::vector<InputType> in_data(elts_num);
std::vector<InputType> grad_data(elts_num);
std::vector<OutputType> out_data_rowwise_h(rowwise ? elts_num : 0);
std::vector<OutputType> out_data_colwise_h(colwise ? elts_num : 0);
std::vector<fp8e8m0> out_scales_rowwise_h(rowwise ? rowwise_sfs_num : 0);
std::vector<fp8e8m0> out_scales_colwise_h(colwise ? colwise_sfs_num : 0);
std::vector<OutputType> out_data_rowwise_ref(rowwise ? elts_num : 0);
std::vector<OutputType> out_data_colwise_ref(colwise ? elts_num : 0);
std::vector<fp8e8m0> out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0);
std::vector<fp8e8m0> out_scales_colwise_ref(colwise ? colwise_sfs_num : 0);
std::vector<InputType> ref_output_dbias(is_single_tensor ? cols : 0);
for (size_t i = 0; i < elts_num; ++i) {
const float val = dis(gen);
grad_data[i] = static_cast<InputType>(val);
in_data[i] = static_cast<InputType>(val);
}
const OutputType zero_elt = static_cast<OutputType>(0.0f);
const fp8e8m0 zero_SF = static_cast<fp8e8m0>(0.0f);
if (rowwise) {
std::fill(out_data_rowwise_h.begin(), out_data_rowwise_h.end(), zero_elt);
std::fill(out_data_rowwise_ref.begin(), out_data_rowwise_ref.end(), zero_elt);
std::fill(out_scales_rowwise_h.begin(), out_scales_rowwise_h.end(), zero_SF);
std::fill(out_scales_rowwise_ref.begin(), out_scales_rowwise_ref.end(), zero_SF);
}
if (colwise) {
std::fill(out_data_colwise_h.begin(), out_data_colwise_h.end(), zero_elt);
std::fill(out_data_colwise_ref.begin(), out_data_colwise_ref.end(), zero_elt);
std::fill(out_scales_colwise_h.begin(), out_scales_colwise_h.end(), zero_SF);
std::fill(out_scales_colwise_ref.begin(), out_scales_colwise_ref.end(), zero_SF);
}
const size_t in_data_size = elts_num * sizeof(InputType);
const size_t out_data_size = elts_num * sizeof(OutputType);
const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0);
const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0);
const size_t first_dims_size = num_tensors * sizeof(size_t);
const size_t last_dims_size = num_tensors * sizeof(size_t);
const size_t offsets_size = (num_tensors + 1) * sizeof(size_t);
InputType* grad_data_d;
InputType* in_data_d;
OutputType* out_data_rowwise_d;
OutputType* out_data_colwise_d;
fp8e8m0* out_scales_rowwise_d;
fp8e8m0* out_scales_colwise_d;
size_t* first_dims_d;
size_t* last_dims_d;
size_t* offsets_d;
cudaMalloc((void**)&grad_data_d, in_data_size);
cudaMalloc((void**)&in_data_d, in_data_size);
cudaMalloc((void**)&first_dims_d, first_dims_size);
cudaMalloc((void**)&last_dims_d, last_dims_size);
cudaMalloc((void**)&offsets_d, offsets_size);
cudaMemcpy(grad_data_d, grad_data.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice);
NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());
NVTEShape first_dims_shape_;
NVTEShape last_dims_shape_;
NVTEShape offsets_shape_;
first_dims_shape_.ndim = 1;
last_dims_shape_.ndim = 1;
offsets_shape_.ndim = 1;
first_dims_shape_.data[0] = num_tensors;
last_dims_shape_.data[0] = num_tensors;
offsets_shape_.data[0] = num_tensors + 1;
NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_);
NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast<NVTEDType>(itype), logical_shape_};
NVTEBasicTensor in_data_tensor = {in_data_d, static_cast<NVTEDType>(itype), logical_shape_};
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor);
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor);
if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
}
if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
}
if (shape_rep != SAME_BOTH_DIMS) {
NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
}
if (rowwise) {
cudaMalloc((void**)&out_data_rowwise_d, out_data_size);
cudaMalloc((void**)&out_scales_rowwise_d, rowwise_scales_size);
cudaMemset(out_data_rowwise_d, 0, out_data_size);
cudaMemset(out_scales_rowwise_d, 0, rowwise_scales_size);
NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size());
NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor);
}
if (colwise) {
cudaMalloc((void**)&out_data_colwise_d, out_data_size);
cudaMalloc((void**)&out_scales_colwise_d, colwise_scales_size);
cudaMemset(out_data_colwise_d, 0, out_data_size);
cudaMemset(out_scales_colwise_d, 0, colwise_scales_size);
NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size());
NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor);
}
Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
// Reference (CPU)
if (is_single_tensor) {
const size_t unpadded_rowwise_blocks_X = divide_round_up(cols, 32);
const size_t unpadded_colwise_blocks_X = cols;
const size_t scales_stride_rowwise = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
const size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(),
out_data_rowwise_ref.data(), out_data_colwise_ref.data(),
out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(),
ref_output_dbias.data(), rows, cols,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
} else {
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];
const size_t scales_stride_rowwise = rowwise_scales_last_dim[t];
const size_t scales_stride_colwise = colwise_scales_last_dim[t];
const size_t data_offset = offsets_h[t];
const size_t rowwise_sfs_offset = rowwise_scales_offset[t];
const size_t colwise_sfs_offset = colwise_scales_offset[t];
const InputType* const grad_ptr = grad_data.data() + data_offset;
const InputType* const in_ptr = in_data.data() + data_offset;
OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset;
OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset;
fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset;
fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset;
compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_ptr, grad_ptr,
out_data_rowwise_ptr, out_data_colwise_ptr,
out_scales_rowwise_ptr, out_scales_colwise_ptr,
ref_output_dbias.data(), M, K,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
}
}
// GPU
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_group_quantize(in_group_tensor, out_group_tensor, 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
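// The first call runs with an empty workspace and only reports the required workspace
// shape/dtype; the workspace is then allocated and the call is repeated to do the work.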
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
auto nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; }
nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
auto nvte_group_act = &nvte_group_gelu;
if (OP == &silu) { nvte_group_act = &nvte_group_silu; }
else if (OP == &relu) { nvte_group_act = &nvte_group_relu; }
else if (OP == &qgelu) { nvte_group_act = &nvte_group_qgelu; }
else if (OP == &srelu) { nvte_group_act = &nvte_group_srelu; }
nvte_group_act(in_group_tensor, out_group_tensor, 0);
break;
}
case ProcessingMethod::CAST_DACT: {
auto nvte_group_dact = &nvte_group_dgelu;
if (OP == &dsilu) { nvte_group_dact = &nvte_group_dsilu; }
else if (OP == &drelu) { nvte_group_dact = &nvte_group_drelu; }
else if (OP == &dqgelu) { nvte_group_dact = &nvte_group_dqgelu; }
else if (OP == &dsrelu) { nvte_group_dact = &nvte_group_dsrelu; }
nvte_group_dact(grad_group_tensor, in_group_tensor, out_group_tensor, 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(otype);
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;
if (rowwise) {
cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost);
size_t mismatches_scales = 0;
compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(),
1, rowwise_sfs_num, rowwise_sfs_num, mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
compare_scaled_elts<OutputType>("rowwise_output", out_data_rowwise_ref.data(),
out_data_rowwise_h.data(), rows, cols, true, mismatches_elts);
}
if (colwise) {
cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, colwise_scales_size, cudaMemcpyDeviceToHost);
size_t mismatches_scales = 0;
compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(),
1, colwise_sfs_num, colwise_sfs_num, mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
compare_scaled_elts<OutputType>("colwise_output", out_data_colwise_ref.data(),
out_data_colwise_h.data(), rows, cols, false, mismatches_elts);
}
if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
rtol_dbias *= sqrt(static_cast<double>(rows)) ;
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.data(), true, atol_dbias, rtol_dbias);
}
cudaFree(grad_data_d);
cudaFree(in_data_d);
cudaFree(first_dims_d);
cudaFree(last_dims_d);
cudaFree(offsets_d);
if (rowwise) {
cudaFree(out_data_rowwise_d);
cudaFree(out_scales_rowwise_d);
}
if (colwise) {
cudaFree(out_data_colwise_d);
cudaFree(out_scales_colwise_d);
}
}
std::vector<ProcessingMethod> processing_methods = {
ProcessingMethod::CAST_ONLY,
ProcessingMethod::CAST_DBIAS,
ProcessingMethod::CAST_DBIAS_DACT,
ProcessingMethod::CAST_DACT,
ProcessingMethod::CAST_ACT,
};
std::vector<ActivationKind> activation_kinds = {
ActivationKind::Identity,
ActivationKind::GeLU,
// ActivationKind::SiLU,
// ActivationKind::ReLU,
// ActivationKind::QGeLU,
// ActivationKind::SReLU,
};
enum ScalingDirection {
ROWWISE = 0,
COLWISE = 1,
BOTH = 2
};
std::vector<ScalingDirection> scaling_directions = {
ScalingDirection::ROWWISE,
ScalingDirection::COLWISE,
ScalingDirection::BOTH,
};
// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]}
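// Example: {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256} describes two
// tensors of shapes 128x128 and 256x256 stored back-to-back, with the group viewed
// logically as a flat 1 x (128*128 + 256*256) buffer.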
std::vector<std::vector<size_t>> input_config = {
{SAME_BOTH_DIMS, 1, 128,128},
{SAME_BOTH_DIMS, 2, 256,128},
{VARYING_FIRST_DIM, 2, 512,128, 128,384},
{VARYING_FIRST_DIM, 2, 384,160, 128,256},
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
{VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640},
};
} // namespace
class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<ProcessingMethod,
ActivationKind,
ScalingDirection,
std::vector<size_t>, // Config
transformer_engine::DType, // InputType
transformer_engine::DType // OutputType
>> {};
TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationKind activation = std::get<1>(GetParam());
const ScalingDirection scaling_direction = std::get<2>(GetParam());
const std::vector<size_t> input_config = std::get<3>(GetParam());
const DType input_type = std::get<4>(GetParam());
const DType output_type = std::get<5>(GetParam());
const ShapeRepresentation shape_rep = static_cast<ShapeRepresentation>(input_config[0]);
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);
const size_t num_tensors = input_config[1];
const std::vector<size_t> logical_shape = {input_config[2], input_config[3]};
std::vector<size_t> first_dims(num_tensors);
std::vector<size_t> last_dims(num_tensors);
std::vector<size_t> offsets(num_tensors + 1, 0);
for (size_t t = 0; t < num_tensors; ++t) {
switch (shape_rep) {
case SAME_BOTH_DIMS: {
first_dims[t] = logical_shape[0] / num_tensors;
last_dims[t] = logical_shape[1];
break;
}
case VARYING_FIRST_DIM: {
first_dims[t] = input_config[t + 4];
last_dims[t] = logical_shape[1];
break;
}
case VARYING_LAST_DIM: {
first_dims[t] = logical_shape[0];
last_dims[t] = input_config[t + 4];
break;
}
case VARYING_BOTH_DIMS: {
first_dims[t] = input_config[t + 4];
last_dims[t] = input_config[t + (4 + num_tensors)];
break;
}
}
offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t];
// Skips tests if tensor shape is not as required by the kernel
if ((first_dims[t] % 128 != 0) || (last_dims[t] % 32 != 0)) {
GTEST_SKIP();
}
}
// Skips DBias tests if the last dimension varies across tensors
if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
&& !is_single_tensor) {
GTEST_SKIP();
}
// Skips tests that do not apply an activation (CAST_ONLY, CAST_DBIAS) unless the activation is Identity
if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
&& activation != ActivationKind::Identity) {
GTEST_SKIP();
}
// Skips activation and dactivation tests when the activation is Identity
if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
|| processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) {
GTEST_SKIP();
}
bool rowwise = false;
bool colwise = false;
switch (scaling_direction) {
case ScalingDirection::ROWWISE: rowwise = true; break;
case ScalingDirection::COLWISE: colwise = true; break;
case ScalingDirection::BOTH: rowwise = true; colwise = true; break;
}
auto OP = &identity;
if (processing_method == ProcessingMethod::CAST_ACT) {
switch (activation) {
case ActivationKind::GeLU: OP = &gelu; break;
case ActivationKind::SiLU: OP = &silu; break;
case ActivationKind::ReLU: OP = &relu; break;
case ActivationKind::QGeLU: OP = &qgelu; break;
case ActivationKind::SReLU: OP = &srelu; break;
}
} else if (processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
switch (activation) {
case ActivationKind::GeLU: OP = &dgelu; break;
case ActivationKind::SiLU: OP = &dsilu; break;
case ActivationKind::ReLU: OP = &drelu; break;
case ActivationKind::QGeLU: OP = &dqgelu; break;
case ActivationKind::SReLU: OP = &dsrelu; break;
}
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(processing_method, OP, shape_rep, num_tensors,
logical_shape, first_dims, last_dims, offsets,
rowwise, colwise);
);
);
}
std::string to_string(const ProcessingMethod method) {
switch (method) {
case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
case ProcessingMethod::CAST_DACT: return "CAST_DACT";
case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default: return "";
}
}
std::string to_string(const ActivationKind activation) {
switch (activation) {
case ActivationKind::Identity: return "Identity";
case ActivationKind::GeLU: return "GeLU";
case ActivationKind::SiLU: return "SiLU";
case ActivationKind::ReLU: return "ReLU";
case ActivationKind::QGeLU: return "QGeLU";
case ActivationKind::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
GroupedFusedCastMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(processing_methods),
::testing::ValuesIn(activation_kinds),
::testing::ValuesIn(scaling_directions),
::testing::ValuesIn(input_config),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)),
[](const testing::TestParamInfo<GroupedFusedCastMXFP8TestSuite::ParamType>& info) {
const ProcessingMethod method = std::get<0>(info.param);
std::string name = to_string(method);
name += "X" + to_string(std::get<1>(info.param));
switch (std::get<2>(info.param)) {
case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break;
case ScalingDirection::COLWISE: name += "_COLWISE_"; break;
case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break;
}
const std::vector<size_t> input = std::get<3>(info.param);
switch(static_cast<ShapeRepresentation>(input[0])) {
case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break;
case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break;
case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break;
case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break;
};
name += "_N_" + std::to_string(input[1]);
name += "_SHAPE_" +
std::to_string(input[2]) +
"X" + std::to_string(input[3]);
name += "_" + test::typeName(std::get<4>(info.param)) +
"_" + test::typeName(std::get<5>(info.param));
return name;
});
......
@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_gelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
......
@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
......
@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_qgelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
......
@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dqgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......
@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
......
......
@@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_relu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, relu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
......
@@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_drelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
......
@@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_srelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, srelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
......
@@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsrelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......
@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
......
......
@@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_silu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, silu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu);
......
@@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsilu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
......
......
@@ -26,6 +26,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize);
using namespace transformer_engine;
constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_noop);
......
@@ -60,6 +69,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr const NVTEGroupedTensor activation_input = nullptr;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
......
......
@@ -18,6 +18,7 @@
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/group_quantize_mxfp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
......
@@ -371,6 +372,89 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
}
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);
const NVTEGroupedTensor activation = nullptr;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation);
Tensor *dbias_tensor = convertNVTETensor(dbias);
Tensor *workspace_tensor = convertNVTETensor(workspace);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
mxfp8::group_quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);
const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad);
const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input);
GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
Tensor *dbias_tensor = convertNVTETensor(dbias);
Tensor *workspace_tensor = convertNVTETensor(workspace);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
mxfp8::group_quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file group_quantize_mxfp8.cuh
* \brief CUDA kernels to quantize grouped tensors to MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace group_quantize_kernel {
constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64;
__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
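// One TMA descriptor per member tensor, stored in device global memory. The
// update_tma_descriptors kernel below rewrites these from the base descriptors
// (new data pointer, dims and stride) before the quantization kernel consumes them.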
enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X;
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
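// Each 128x128 chunk is processed in STAGES passes of BUFF_DIM_Y (=32) rows, with
// BUFFS_NUM (=2) shared-memory buffers so TMA loads of the next stage can overlap
// with computation on the current one.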
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
__device__ __forceinline__ size_t get_current_tensor_id(
const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
const size_t first_logical_dim, const size_t last_logical_dim,
const int64_t *const __restrict__ offsets_ptr) {
if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
const size_t current_row = current_offset / last_logical_dim;
const size_t rows_per_tensor = first_logical_dim / num_tensors;
return current_row / rows_per_tensor;
} else {
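// Upper-bound binary search over the prefix-sum offsets: return the last tensor
// whose start offset is <= current_offset.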
size_t low = 1;
size_t hi = num_tensors; // [low, hi]
while (low < hi) {
const size_t mid = low + (hi - low) / 2;
const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
if (mid_offset <= current_offset) {
low = mid + 1;
} else {
hi = mid;
}
}
return low - 1;
}
}
__device__ __forceinline__ size_t get_tensor_rows_num(
const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim,
const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) {
size_t rows_num = 0;
switch (shape_rep) {
case ShapeRepresentation::SAME_BOTH_DIMS:
case ShapeRepresentation::VARYING_LAST_DIM:
rows_num = first_logical_dim;
break;
case ShapeRepresentation::VARYING_FIRST_DIM:
case ShapeRepresentation::VARYING_BOTH_DIMS:
rows_num = static_cast<size_t>(first_dims_ptr[tensor_id]);
break;
}
return rows_num;
}
__device__ __forceinline__ size_t get_tensor_cols_num(
const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim,
const int64_t *const __restrict__ last_dims_ptr) {
size_t cols_num = 0;
switch (shape_rep) {
case ShapeRepresentation::SAME_BOTH_DIMS:
case ShapeRepresentation::VARYING_FIRST_DIM:
cols_num = last_logical_dim;
break;
case ShapeRepresentation::VARYING_LAST_DIM:
case ShapeRepresentation::VARYING_BOTH_DIMS:
cols_num = static_cast<size_t>(last_dims_ptr[tensor_id]);
break;
}
return cols_num;
}
// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index
__device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map,
CUtensorMap *global_tensor_map,
const uintptr_t global_data_ptr,
const size_t global_dim_Y,
const size_t global_dim_X,
const size_t data_type_size_bytes) {
__shared__ CUtensorMap shared_tensor_map;
shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
const size_t global_stride_bytes = global_dim_X * data_type_size_bytes;
if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) {
NVTE_DEVICE_ERROR("Shape not supported, as data stride must be 16B aligned.");
}
if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) {
NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned");
}
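// Patch the copy of the base descriptor held in shared memory: new global address,
// new Y/X dimensions and new row stride. All other fields (tile box size, swizzle
// mode, element type) are inherited from the base descriptor.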
asm volatile(
"{\n\t"
".reg.b64 tensor_map_ptr; \n\t"
"mov.b64 tensor_map_ptr, %0; \n\t"
"tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t"
"tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y
"tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X
"tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n"
"}\n" ::"l"(reinterpret_cast<uintptr_t>(&shared_tensor_map)),
"l"(global_data_ptr), "r"(static_cast<uint32_t>(global_dim_Y)),
"r"(static_cast<uint32_t>(global_dim_X)), "l"(static_cast<uint64_t>(global_stride_bytes))
: "memory");
*global_tensor_map = shared_tensor_map;
} else {
NVTE_DEVICE_ERROR(
"tensormap.replace is architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
template <typename IType, typename OType>
__global__ void update_tma_descriptors(
const __grid_constant__ CUtensorMap base_tensor_map_input,
const __grid_constant__ CUtensorMap base_tensor_map_act_input,
const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap base_tensor_map_output_colwise,
const IType *const __restrict__ input_data_ptr,
const IType *const __restrict__ act_input_data_ptr,
const OType *const __restrict__ output_rowwise_data_ptr,
const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep,
const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim,
const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr,
const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise,
const bool compute_dactivations) {
const bool leading_thread = (threadIdx.x == 0);
const size_t tensor_id = blockIdx.x;
const size_t rows =
get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
const size_t offset_elts = offsets_ptr[tensor_id];
if (leading_thread && (tensor_id < num_tensors)) {
{
const uintptr_t global_data_ptr = reinterpret_cast<uintptr_t>(input_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id],
global_data_ptr, rows, cols, sizeof(IType));
}
if (compute_dactivations) {
const uintptr_t global_data_ptr =
reinterpret_cast<uintptr_t>(act_input_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id],
global_data_ptr, rows, cols, sizeof(IType));
}
if (rowwise) {
const uintptr_t global_data_ptr =
reinterpret_cast<uintptr_t>(output_rowwise_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_output_rowwise,
&g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, cols,
sizeof(OType));
}
if (colwise) {
const uintptr_t global_data_ptr =
reinterpret_cast<uintptr_t>(output_colwise_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_output_colwise,
&g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, cols,
sizeof(OType));
}
}
}
__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map));
#else
NVTE_DEVICE_ERROR("fence_acquire_tensormap is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
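// The per-tensor descriptors in g_tensor_maps_* are written through the generic proxy by
// update_tma_descriptors(); this 128-byte acquire fence makes the updated descriptor bytes
// visible to the TMA unit before the first cp.async.bulk.tensor that consumes them.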
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
const __grid_constant__ CUtensorMap tensor_map_input_static,
const __grid_constant__ CUtensorMap tensor_map_act_input_static,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static,
const __grid_constant__ CUtensorMap tensor_map_output_colwise_static,
const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim,
const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr,
const int64_t *const __restrict__ first_dims_ptr,
const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr,
e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop,
float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT;
constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS;
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK;
const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset,
first_logical_dim, last_logical_dim, offsets_ptr);
const size_t rows =
get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast<size_t>(32)), 4);
const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128);
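  // E.g. (illustrative) cols = 96: scale_stride_rowwise = DIVUP(96, 32) = 3, rounded up to 4;
  // scale_stride_colwise = 96 rounded up to the next multiple of 128 = 128.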
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM);
  // a grouped tensor can be treated as a contiguous tensor for MXFP8
const size_t tensor_base = is_single_tensor ? 0 : static_cast<size_t>(offsets_ptr[tensor_id]);
const CUtensorMap &tensor_map_input =
is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id];
const CUtensorMap &tensor_map_act_input =
is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id];
const CUtensorMap &tensor_map_output_rowwise =
is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id];
const CUtensorMap &tensor_map_output_colwise =
is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id];
const bool leading_thread = (threadIdx.x == 0);
if (leading_thread && (!is_single_tensor)) {
fence_acquire_tensormap(&tensor_map_input);
if constexpr (COMPUTE_ACTIVATIONS) {
fence_acquire_tensormap(&tensor_map_act_input);
}
if constexpr (ROWWISE_SCALING) {
fence_acquire_tensormap(&tensor_map_output_rowwise);
}
if constexpr (COLWISE_SCALING) {
fence_acquire_tensormap(&tensor_map_output_colwise);
}
}
const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast<size_t>(128));
const size_t block_id_in_current_tensor =
is_single_tensor ? blockIdx.x : (blockIdx.x - tensor_base / ELTS_PER_CHUNK);
const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor;
const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor;
const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y;
const size_t block_offset_X = block_id_X * CHUNK_DIM_X;
e8m0_t *const scales_rowwise =
scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X);
e8m0_t *const scales_colwise =
scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y);
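  // Each rowwise E8M0 scale covers SCALE_DIM_X consecutive elements of a row and each colwise
  // scale covers SCALE_DIM_Y consecutive elements of a column, so the per-tensor base offset
  // into the packed scale buffers is the element offset divided by the scale block size.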
const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X;
const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y;
const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X;
const size_t tid_Y_colwise = 0;
const size_t tid_X_colwise = threadIdx.x;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
  // helps to avoid bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
extern __shared__ unsigned char dynamic_shmem[];
unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
float partial_dbias_colwise = 0.0f;
float thread_dbias_rowwise[SCALE_DIM_X];
if constexpr (IS_DBIAS) {
#pragma unroll
for (int j = 0; j < SCALE_DIM_X; ++j) {
thread_dbias_rowwise[j] = 0.0f;
}
}
float block_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, leading_thread);
int parity = 0;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0],
&tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], leading_thread);
} else {
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], leading_thread);
}
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
leading_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity);
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
thread_amax = 0.0f;
float in_compute_colwise[BUFF_DIM_Y];
IType in_colwise_IType[BUFF_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
}
thread_amax = static_cast<float>(thread_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[shmem_offset_colwise]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
partial_dbias_colwise += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
thread_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
        // Since the TMA data-alignment requirement is 16B (i.e. cols % 8 == 0 for BF16 elements),
        // a single check along the column direction is sufficient to ensure the entire wave is inside the boundaries
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[shmem_offset_rowwise]);
}
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in.data.elt[e]);
elt *= OP(act_in_elt, {});
}
// If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again
if constexpr (IS_DBIAS && (!COLWISE_SCALING)) {
thread_dbias_rowwise[j] += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<OType2, PACK_SIZE / 2> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
IType2 in;
OType2 &out_pair = reinterpret_cast<OType2 &>(out.data.elt[e]);
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = in_IType[w].data.elt[e];
} else if constexpr (IS_CACHED_ACT_OP) {
in.x = in_cached[w].data.elt[2 * e];
in.y = in_cached[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in.x = in_compute_rowwise[j];
in.y = in_compute_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
__builtin_assume(block_amax >= 0);
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (leading_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
parity ^= 1;
if constexpr (IS_DBIAS) {
if (is_single_tensor) {
float thread_partial_dbias = 0.0f;
if constexpr (COLWISE_SCALING) {
thread_partial_dbias = partial_dbias_colwise;
} else {
// Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
// HEIGHT = THREADS_Y
// WIDTH = THREADS_X * (SCALE_DIM_X + 1)
// Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
const int dbias_stride = cols;
const int dbias_offset_Y = block_id_Y;
const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
}
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
}
if (leading_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, block_amax);
}
destroy_barriers<STAGES>(mbar, leading_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace group_quantize_kernel
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void group_quantize(const GroupedTensor *input, const GroupedTensor *activations,
const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
using namespace group_quantize_kernel;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
const bool use_rowwise_scaling = output->has_data();
const bool use_colwise_scaling = output->has_columnwise_data();
NVTE_CHECK(use_rowwise_scaling || use_colwise_scaling,
"Either rowwise or columnwise output data need to be allocated.");
ScalingType scaling_type = ScalingType::BIDIMENSIONAL;
if (!use_colwise_scaling) {
scaling_type = ScalingType::ROWWISE;
} else if (!use_rowwise_scaling) {
scaling_type = ScalingType::COLWISE;
}
ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
if (output->all_same_shape()) {
shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
} else if (output->all_same_first_dim()) {
shape_rep = ShapeRepresentation::VARYING_LAST_DIM;
} else if (output->all_same_last_dim()) {
shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
} else if (output->varying_both_dims()) {
shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
}
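  // Illustrative classification (hypothetical shapes): {[128,256],[128,256]} -> SAME_BOTH_DIMS;
  // {[64,256],[128,256]} -> VARYING_FIRST_DIM; {[128,256],[128,512]} -> VARYING_LAST_DIM;
  // {[64,256],[128,512]} -> VARYING_BOTH_DIMS (its logical shape collapses to [1, M*K]).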
  // Treat a grouped tensor with a constant last dim as a single contiguous tensor
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM);
NVTE_CHECK(input->num_tensors == output->num_tensors,
"Number of input and output tensors must be same.");
NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
if (IS_DACT) {
NVTE_CHECK(activations->has_data(), "Activations tensor must have data.");
NVTE_CHECK(input->num_tensors == activations->num_tensors,
"Number of grad and activations tensors must be same.");
NVTE_CHECK(input->dtype() == activations->dtype(),
"Grad and activations tensors must have the same type.");
}
const size_t first_logical_dim = input->logical_shape.data[0];
const size_t last_logical_dim = input->logical_shape.data[1];
const size_t elts_total = first_logical_dim * last_logical_dim;
const size_t num_tensors = input->num_tensors;
size_t blocks = 0;
if (is_single_tensor) {
const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X);
blocks = blocks_Y * blocks_X;
} else {
NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS,
"Number of tensors in a group is larger than "
"the MAX number of supported descriptors (64).");
// Only full tiles supported
NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
"Last dimension of a grouped tensor should be divisible by 128.");
blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
}
const dim3 grid(blocks);
const size_t block_size = THREADS_PER_CHUNK;
  // The logical shape of a grouped tensor with both dims varying is [1, M*K]
if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
NVTE_CHECK(first_logical_dim % 128 == 0,
"First dimension of a grouped tensor should be divisible by 128.");
}
const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
e8m0_t *const scales_rowwise_ptr = reinterpret_cast<e8m0_t *>(output->scale_inv.dptr);
e8m0_t *const scales_colwise_ptr = reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr);
if (use_rowwise_scaling) {
NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated");
}
if (use_colwise_scaling) {
NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated");
}
const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y);
const size_t dbias_cols = last_logical_dim;
if constexpr (IS_DBIAS) {
NVTE_CHECK(is_single_tensor,
"DBias is only supported for tensors with the const last dimension.");
NVTE_CHECK(dbias->data.dtype == input->dtype(),
"DBias must have the same type as input_tensor.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{last_logical_dim}, "Wrong shape of DBias.");
NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {dbias_rows, dbias_cols};
workspace->data.dtype = DType::kFloat32;
return;
}
}
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input->dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim,
BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size);
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
output_type_bit_size);
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
last_logical_dim, 0, output_type_bit_size);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
auto kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true>;
switch (scaling_type) {
case ScalingType::ROWWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false>;
break;
}
case ScalingType::COLWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true>;
break;
}
case ScalingType::BIDIMENSIONAL: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true>;
break;
}
}
// Update tensor descriptors before launching the kernel
if (!is_single_tensor) {
const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);
const IType *const act_input_dptr =
IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;
OType *const output_rowwise_dptr =
use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;
OType *const output_colwise_dptr =
use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
: nullptr;
update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim,
offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling,
use_colwise_scaling, IS_DACT);
}
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);
if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
}
} // namespace mxfp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
......@@ -52,6 +52,16 @@ enum class NVTE_Activation_Type {
*/
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
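/* Illustrative usage (sketch only; the grouped-tensor constructors below are hypothetical
 * placeholders, not part of this header; the other nvte_group_* activations follow the same
 * pattern):
 *
 *   NVTEGroupedTensor x = make_bf16_grouped_tensor(shapes, num_tensors);   // hypothetical helper
 *   NVTEGroupedTensor y = make_mxfp8_grouped_tensor(shapes, num_tensors);  // NVTE_MXFP8_1D_SCALING
 *   nvte_group_gelu(x, y, stream);
 */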
/*! \brief Computes the SiLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -62,6 +72,16 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -72,6 +92,16 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -82,6 +112,16 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -92,6 +132,16 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -104,6 +154,18 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -116,6 +178,18 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -128,6 +202,18 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -140,6 +226,18 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -152,6 +250,18 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......
......@@ -89,6 +89,17 @@ extern "C" {
*/
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8.
* The type of quantized tensor in the output depends on the scaling mode of the output
* tensor. See file level comments.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped MXFP8 tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream);
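/* Illustrative usage (sketch only; the grouped-tensor constructors are hypothetical placeholders,
 * not part of this header):
 *
 *   NVTEGroupedTensor in  = make_bf16_grouped_tensor(shapes, num_tensors);   // high-precision input
 *   NVTEGroupedTensor out = make_mxfp8_grouped_tensor(shapes, num_tensors);  // rowwise and/or columnwise
 *                                                                            // MXFP8 data allocated
 *   nvte_group_quantize(in, out, stream);
 */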
/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
* The type of quantized tensor in the output depends on the scaling mode of the output
......@@ -132,6 +143,26 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
 * - `output` is equal to `cast(input)`
 * - `dbias` is equal to `reduce(input, dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
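/* Illustrative two-call pattern for the workspace query described above (sketch only; the tensor
 * construction/allocation helpers are hypothetical placeholders):
 *
 *   NVTETensor workspace = make_empty_tensor();                    // hypothetical: dptr == nullptr
 *   nvte_group_quantize_dbias(in, out, dbias, workspace, stream);  // 1st call: only sets the
 *                                                                  // required workspace shape/dtype
 *   allocate_tensor_data(workspace);                               // hypothetical allocation step
 *   nvte_group_quantize_dbias(in, out, dbias, workspace, stream);  // 2nd call: performs cast + dbias
 */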
/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -155,6 +186,29 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of GeLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -178,6 +232,29 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -201,6 +278,29 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -224,6 +324,29 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -247,6 +370,29 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Casts input tensor from reduced to higher precision.
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
* the block dequantization (MXFP8) of the specified shape of the block will be used.
......