Unverified commit fbb16f4a authored by Oleg Goncharov, committed by GitHub

[Common] Tuned NVFP4 cast kernel (#2412)



* Implemented persistent nvfp4 kernel
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix FP4 guard in ptx
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fix in ptx: redux.f32 guard
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixes per PR review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixes per PR review. Added a parameter to turn off persistency
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Modified the reference CPU implementation in the C++ unit tests to match the GPU (numerical truncation; sketched below). Tightened the numerical tolerance
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
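
  A minimal sketch of what that truncation means (hypothetical helper for illustration only;
  the actual change is in the diff below, where the per-block scale reciprocal is passed through
  bf16 when use_fast_math is set, mirroring the GPU's mixed-precision fma.rn.f32.bf16 path):

    #include <cuda_bf16.h>

    // Round a scale reciprocal through bf16, as the CPU reference does in fast-math mode.
    inline float truncate_scale_like_gpu(float scale_reciprocal, bool use_fast_math) {
      if (use_fast_math) {
        // bf16 keeps about 8 bits of precision: e.g. 1.00390625f (1 + 2^-8) rounds back to 1.0f
        scale_reciprocal = __bfloat162float(__float2bfloat16(scale_reciprocal));
      }
      return scale_reciprocal;
    }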

* Disabled persistency by default, as the non-persistent kernel is more performant when inputs are large
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use the tuned kernel for rowwise-only quantization as well
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed typo
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Addressed comments from the PR review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Resolved conflicts
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Renamed macros
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
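
For context, a rough worked example of the two-level NVFP4 scaling computed by the reference
CPU implementation in the test changes below (values picked so every step is exactly
representable; per-block scales are stored as FP8 E4M3 and elements as FP4 E2M1):

    // Global encode scale: S_enc = fp8_max * fp4_max / global_amax = 448 * 6 / 672 = 4
    // One 16-element block with block_amax = 3:
    //   S_dec_b     = block_amax / 6        = 0.5
    //   S_dec_b_fp8 = e4m3(S_dec_b * S_enc) = e4m3(2.0) = 2.0   (stored block scale)
    //   S_enc_b     = S_enc / S_dec_b_fp8   = 2.0
    // An element x = 3 quantizes to e2m1(x * S_enc_b) = e2m1(6.0) = 6.0 (FP4 max),
    // and dequantizes back as 6.0 * S_dec_b_fp8 / S_enc = 6 * 2 / 4 = 3 = x.
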
parent 27fc168e
@@ -54,12 +54,16 @@ std::vector<InputType> create_transpose(const InputType* const input, const size
}
// Compute the global encode scale factor for a given global amax
-float compute_global_encode_scaling_factor_FP4(const float global_amax) {
float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) {
constexpr float fp8_max = 448.0f;  // 448.0f;
constexpr float fp4_max = 6.0f;  // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
-// If scale is infinity, return max value of float32
-global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
// If scale is infinity, return the max normalized value
const float max_norm_clamp = use_fast_math
? Numeric_Traits<bf16>::maxNorm
: Numeric_Traits<float>::maxNorm;
global_encode_scale = fminf(global_encode_scale, max_norm_clamp);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
@@ -76,10 +80,11 @@ void quantize_nvfp4_1d(float (*OP)(const float),
const size_t rows,
const size_t cols,
const size_t scales_stride,
-const float global_amax) {
const float global_amax,
const bool use_fast_math) {
// Compute a global encoding/decoding scaling factor for all S_dec_b
-const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_X = 16;
const size_t blocks_X = divide_round_up(cols, block_size_X);
@@ -114,14 +119,20 @@ void quantize_nvfp4_1d(float (*OP)(const float),
const float S_dec_b = block_amax / 6.0f;
// Scale & Store per-block decoding scaling factor
-const float S_dec_b_fp8 = S_dec_b * S_enc;
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);
// Compute "correct" per-block encoding scaling factor
-const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
const size_t scale_idx = i * scales_stride + block_X;
-scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
-const float scale_reciprocal = S_enc_b_fp8;
scales[scale_idx] = S_dec_b_fp8;
float scale_reciprocal = S_enc_b_fp8;
if (use_fast_math) {
// Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used
scale_reciprocal = static_cast<float>(static_cast<bf16>(scale_reciprocal));
}
for (size_t j = j_min; j < j_max; j += 2) {
const int idx_pair = (i * cols + j) / 2;
@@ -136,7 +147,7 @@ void quantize_nvfp4_1d(float (*OP)(const float),
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
output[idx_pair] = casted_to_e2m1_pair;
-// const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
}
}
}
@@ -149,9 +160,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
const size_t rows,
const size_t cols,
const float global_amax,
-std::vector<std::vector<fp8e4m3>>& math_scales) {
std::vector<std::vector<fp8e4m3>>& math_scales,
const bool use_fast_math) {
-const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -195,13 +207,14 @@ void quantize_nvfp4_2d(float (*OP)(const float),
const size_t rows,
const size_t cols,
const size_t scales_stride,
-const float global_amax) {
const float global_amax,
const bool use_fast_math) {
// Step 1: Compute mathematical 8x8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
-compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
-const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -282,11 +295,12 @@ void quantize_nvfp4(float (*OP)(const float),
const size_t cols,
const size_t scales_stride,
const float global_amax,
const bool use_fast_math,
const bool use_2d_quantization = false) {
if (use_2d_quantization) {
-quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
} else {
-quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
}
}
@@ -302,6 +316,7 @@ void compute_ref(float (*OP)(const float),
const size_t cols,
const size_t scales_stride,
const size_t scales_stride_t,
const bool use_fast_math,
const bool use_2d_quantization = false)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);
@@ -309,7 +324,7 @@ void compute_ref(float (*OP)(const float),
if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
-compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
@@ -336,12 +351,16 @@ void compute_ref(float (*OP)(const float),
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
// (This part processes the actual FP4 data using the mathematical scaling factors)
-quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax);  // scales already filled
-quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax);  // scales_t already filled
quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax,
use_fast_math);  // scales already filled
quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax,
use_fast_math);  // scales_t already filled
} else {
-quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
-quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization);
quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax,
use_fast_math, use_2d_quantization);
quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax,
use_fast_math, use_2d_quantization);
}
}
@@ -349,6 +368,8 @@ void compare_nvfp4_tensors(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8) {
constexpr int max_mismatches_to_print = 3;
std::vector<std::string> mismatch_messages;
size_t total_mismatches = 0;
@@ -362,29 +383,16 @@ void compare_nvfp4_tensors(const std::string& name,
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
-bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
-/* For Float32 the floating point comparison is enough to error out */
-bool assertion = false;
-if (mismatch && !assertion) {
-/* Check if it is just a failure of round to nearest choosing different
-side of the real value */
-const double mean = (t + r) / 2;
-const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
-const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
-const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
-const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
-assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
-}
-if (assertion) {
const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol);
if (mismatch) {
total_mismatches++;
// Optional: limit number of detailed messages to avoid overwhelming output
if (total_mismatches <= max_mismatches_to_print) {
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
std::to_string(t) + " vs " + std::to_string(r) +
" (abs_diff: " + std::to_string(fabs(t - r)) +
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
mismatch_messages.push_back(msg);
-// Optional: limit number of detailed messages to avoid overwhelming output
-if (mismatch_messages.size() <= 100) {
std::cout << "Error in tensor " << name << ": " << msg << std::endl;
}
}
@@ -400,8 +408,9 @@ void compare_nvfp4_tensors(const std::string& name,
std::cout << "STATUS: FAILED for output" << std::endl;
std::cout << "Total mismatches found: " << total_mismatches << std::endl;
std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
-if (mismatch_messages.size() > 100) {
-std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
if (mismatch_messages.size() > max_mismatches_to_print) {
std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print)
<< " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl;
}
std::cout << "============================" << std::endl;
@@ -519,7 +528,8 @@ void compareResults_nvfp4(const Tensor &test,
template <typename InputType>
void performTest(float (*OP)(const float),
-const std::vector<size_t>& shape) {
const std::vector<size_t>& shape,
const bool use_fast_math) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
@@ -580,15 +590,16 @@ void performTest(float (*OP)(const float),
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
-QuantizationConfigWrapper quant_config;
// Initialize stochastic rounding
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123;  // rng_seed
rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321;  // rng_sequence
rng_state.from_cpu();
QuantizationConfigWrapper quant_config;
quant_config.set_use_fast_math(use_fast_math);
quant_config.set_stochastic_rounding(false);
quant_config.set_rng_state(rng_state.data());
@@ -619,8 +630,8 @@ void performTest(float (*OP)(const float),
}
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
-const double atol = 0.05;
-const double rtol = 0.1;
const double atol = 1.0E-6;
const double rtol = 1.0E-6;
// Set dump_data=true to enable dumping tensor data to files for analysis
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);
@@ -666,12 +677,18 @@ std::vector<ActivationType> Activation_types = {
ActivationType::Identity
};
std::vector<bool> use_fast_nvfp4_scaling_vec = {
false,
true
};
}  // namespace
class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<ActivationType,
std::vector<size_t>,
-transformer_engine::DType>> {};
transformer_engine::DType,
bool>> {};
TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
// Skip tests for pre-Blackwell architectures
@@ -685,6 +702,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
const ActivationType Act_type = std::get<0>(GetParam());
const auto tensor_dims = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const bool use_fast_math = std::get<3>(GetParam());
// Skip tests if the input tensor is 1D
if (tensor_dims.size() < 2) {
@@ -702,7 +720,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
-performTest<InputType>(OP, tensor_dims);
performTest<InputType>(OP, tensor_dims, use_fast_math);
);
}
@@ -724,7 +742,8 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
-::testing::Values(DType::kBFloat16)),
::testing::Values(DType::kBFloat16),
::testing::ValuesIn(use_fast_nvfp4_scaling_vec)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
@@ -732,5 +751,8 @@ INSTANTIATE_TEST_SUITE_P(
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<2>(info.param));
if (std::get<3>(info.param)) {
name += "X_FAST_SCALING";
}
return name;
});
@@ -35,6 +35,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
return cols % alignment_requirement == 0;
}
__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
size_t addr = reinterpret_cast<size_t>(p);
addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
return reinterpret_cast<unsigned char *>(addr);
}
namespace kernel {
constexpr size_t THREADS_PER_BLOCK = 256;
...
@@ -21,6 +21,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh"
namespace transformer_engine {
namespace dispatch {
@@ -1159,6 +1160,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
#if FP4_TYPE_SUPPORTED
using namespace quantize_transpose_kernel;
using namespace ptx;
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
@@ -1166,6 +1168,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
return;
}
constexpr bool COMPUTE_ACTIVATIONS = false;
using ParamOP = Empty;
constexpr float (*OP)(float, const ParamOP &) = nullptr;
...
@@ -164,6 +164,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(
uint64_t *mbar, const uint32_t tx_count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr),
"r"(tx_count));
#else
NVTE_DEVICE_ERROR(
"mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile("fence.mbarrier_init.release.cluster;");
@@ -243,6 +255,75 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar,
uint32_t phase_parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile(
"{\n\t"
".reg .b64 r1; \n\t"
".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met
"WAIT: \n\t" // loop around barrier wait
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t"
"@waitComplete bra DONE; \n\t" // mbarrier conditions are met
"bra WAIT; \n\t" // just a time-out, try again
"DONE: \n\t"
"}\n\t"
:
: "r"(mbar_ptr), "r"(phase_parity)
: "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::"
"all.b128 "
"[%0], [%1];" ::"r"(workID_response),
"r"(mbar_ptr));
} else {
NVTE_DEVICE_ERROR(
"Cluster Launch Control PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr,
int32_t &ctaid_X, int32_t &ctaid_Y) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
asm volatile(
"{\n\t"
".reg .s32 x_ctaid; \n\t"
".reg .s32 y_ctaid; \n\t"
"mov .s32 x_ctaid, -1; \n\t"
"mov .s32 y_ctaid, -1; \n\t"
".reg.b128 try_cancel_response; \n\t"
"ld.shared.b128 try_cancel_response, [%2]; \n\t"
".reg .pred P1; \n\t"
"clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t"
"@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, "
"_}, try_cancel_response; \n\t"
"mov .s32 %0, x_ctaid; \n\t"
"mov .s32 %1, y_ctaid; \n\t"
"}\n\t"
: "=r"(ctaid_X), "=r"(ctaid_Y)
: "r"(workID_response)
: "memory");
} else {
NVTE_DEVICE_ERROR(
"Cluster Launch Control PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
@@ -657,6 +738,179 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
}
}
template <typename SCALING_COEFFICIENT_TYPE>
__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest(
const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) {
uint32_t out_8x = 0;
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
asm volatile(
"{\n"
".reg.f32 zero; \n\t"
"mov.b32 zero, 0; \n\t"
".reg.b16 scaling_coeff; \n\t"
"mov.b16 scaling_coeff, %3; \n\t"
".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
"mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
"mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
".reg.b8 f0, f1, f2, f3; \n\t"
// Elements reordered to match e2m1x4 packing order (v1,v0)
"cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t"
"mov.b32 %0, {f0, f1, f2, f3};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)));
} else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
asm volatile(
"{\n"
".reg.b64 scaling_coeff_2x; \n\t"
"mov.b64 scaling_coeff_2x, {%3, %3}; \n\t"
".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
"mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"cvt.f32.bf16 v4, v4_bf16; \n\t"
"cvt.f32.bf16 v5, v5_bf16; \n\t"
"cvt.f32.bf16 v6, v6_bf16; \n\t"
"cvt.f32.bf16 v7, v7_bf16; \n\t"
".reg.b64 v01, v23, v45, v67; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mov.b64 v45, {v4, v5}; \n\t"
"mov.b64 v67, {v6, v7}; \n\t"
"mul.f32x2 v01, v01, scaling_coeff_2x; \n\t"
"mul.f32x2 v23, v23, scaling_coeff_2x; \n\t"
"mul.f32x2 v45, v45, scaling_coeff_2x; \n\t"
"mul.f32x2 v67, v67, scaling_coeff_2x; \n\t"
// Elements reordered to match the packing order (v1,v0)
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"mov.b64 {v5, v4}, v45; \n\t"
"mov.b64 {v7, v6}, v67; \n\t"
".reg.b8 f0, f1, f2, f3; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7;\n\t"
"mov.b32 %0, {f0, f1, f2, f3};\n\t"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "f"(scaling_coefficient));
} else {
NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
}
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return out_8x;
}
template <typename SCALING_COEFFICIENT_TYPE>
__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding(
const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient,
const uint32_t rbits03, const uint32_t rbits47) {
uint32_t out_8x = 0;
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
asm volatile(
"{\n"
".reg.f32 zero; \n\t"
"mov.b32 zero, 0; \n\t"
".reg.b16 scaling_coeff; \n\t"
"mov.b16 scaling_coeff, %3; \n\t"
".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
"mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
"mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
".reg.b16 b03, b47; \n\t"
// Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
"cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
"mov.b32 %0, {b03, b47};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)),
"r"(rbits03), "r"(rbits47));
} else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
asm volatile(
"{\n"
".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
"mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"cvt.f32.bf16 v4, v4_bf16; \n\t"
"cvt.f32.bf16 v5, v5_bf16; \n\t"
"cvt.f32.bf16 v6, v6_bf16; \n\t"
"cvt.f32.bf16 v7, v7_bf16; \n\t"
"mul.f32 v0, v0, %3; \n\t"
"mul.f32 v1, v1, %3; \n\t"
"mul.f32 v2, v2, %3; \n\t"
"mul.f32 v3, v3, %3; \n\t"
"mul.f32 v4, v4, %3; \n\t"
"mul.f32 v5, v5, %3; \n\t"
"mul.f32 v6, v6, %3; \n\t"
"mul.f32 v7, v7, %3; \n\t"
".reg.b16 b03, b47; \n\t"
// Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
"cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
"mov.b32 %0, {b03, b47};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47));
} else {
NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
}
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return out_8x;
}
#endif  // FP4_TYPE_SUPPORTED
// SIMD like "Fused" cast + multiplication (x2)
@@ -1508,6 +1762,58 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) {
return out;
}
// Loads single BF16/FP16 element from shared memory state space
__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) {
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
bf16 dst;
asm volatile("ld.shared.b16 %0, [%1];"
: "=h"(reinterpret_cast<uint16_t &>(dst))
: "r"(src_smem_ptr));
return dst;
}
// Loads pair of BF16/FP16 values from shared memory state space
__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) {
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
bf16x2 dst;
asm volatile("ld.shared.b32 %0, [%1];"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(src_smem_ptr));
return dst;
}
// Loads 8x BF16 values from shared memory state space
__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) {
uint64_t elts03, elts47;
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
asm volatile(
"{\n\t"
".reg.b128 xy; \n\t"
"ld.shared.b128 xy, [%2]; \n\t"
"mov.b128 {%0, %1}, xy; \n"
"}\n"
: "=l"(elts03), "=l"(elts47)
: "r"(src_smem_ptr));
return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03);
}
#if FP4_TYPE_SUPPORTED
// Vectorized store of x8 FP4 elements into shared memory state space
__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem,
uint32_t fp4_pack_x8) {
const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8));
}
#endif
// Vectorized store of x16 FP4 elements into shared memory state space
#if FP4_TYPE_SUPPORTED
__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem,
uint64_t fp4_pack_x16) {
const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16));
}
#endif
}  // namespace ptx
namespace {
...