Unverified commit cb504cda, authored by Oleg Goncharov and committed by GitHub

[Common] Improved performance of mxfp8 cast kernels (#1628)



* Fixed conflicts
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor code refactoring to avoid unnecessary checks
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed dBias accumulation error due to initialization. Minor code refactoring
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Test case to reproduce the init error
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed rowwise dbias error
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changed ptx API
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added a struct for two packed FP8 values
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Rolled back to scalar code for columnwise scaling due to its better performance
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Minor corrections
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rebased on main
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes per code review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removed constexpr in C++ test suite to build faster
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Computed activations are now numerically truncated to InputType before scaling. Improved test suite.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor refactoring
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Minor refactoring
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Modified mismatch checks of MXFP8 to address FP8 numerics
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Implemented Jeremy's fixes to JAX test suite with an intermediate downcast
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reduced the dims of the test tensors to improve CI runtime
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed memory alignment issue. Compute dbias without downcast.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed misaligned memory issue also in gated kernels. Reduced size of MXFP8 gated tests
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 315b47db
@@ -523,10 +523,13 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {

void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
-                              double atol, double rtol, bool if_on_gpus) {
+                              double atol, double rtol, bool if_on_gpus,
+                              const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
+ size_t mismatches_num = 0;
+ int first_mismatch_idx = -1;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
@@ -547,80 +550,102 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      std::string direction = rowwise ? "rowwise" : "columnwise";
-     ASSERT_FALSE(assertion) << "Error in tensor " << name << " in "
-                             << direction << " direction." << std::endl
-                             << "Mismatch at place " << to_string(unravel(i, shape))
-                             << " (" << std::to_string(i) << "): " << t << " vs " << r;
+     if (assertion) {
+       mismatches_num++;
+       if (first_mismatch_idx == -1) {
+         first_mismatch_idx = i;
+       }
+     }
+     if (mismatches_num > tolerable_mismatches_limit) {
+       const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
+       const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);
+       GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
+                    << tolerable_mismatches_limit << "." << std::endl
+                    << "Error in tensor " << name << " in "
+                    << direction << " direction." << std::endl
+                    << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
+                    << " (" << std::to_string(first_mismatch_idx) << "): "
+                    << first_mismatch_t << " vs " << first_mismatch_r;
+     }
    }
  );
}

template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
-                                 const size_t N, const double atol, const double rtol) {
+                                 const size_t N, const double atol, const double rtol,
+                                 size_t& mismatches) {
  int first_mismatch_idx = N;

- bool is_mismatch_found = false;
- #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \
-                          reduction(min: first_mismatch_idx) proc_bind(spread)
- for (size_t i = 0; i < N; ++i) {
-   if (is_mismatch_found) {  // early escape of the omp thread
-     continue;
-   }
-   double t = static_cast<double>(test_data[i]);
-   double r = static_cast<double>(ref_data[i]);
+ #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
+ {
+   size_t thread_mismatches = 0;
+   #pragma omp for schedule(static)
+   for (size_t i = 0; i < N; ++i) {
+     double t = static_cast<double>(test_data[i]);
+     double r = static_cast<double>(ref_data[i]);

      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && (data_type == DType::kFloat32);
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
-     if (assertion && i < first_mismatch_idx) {
-       first_mismatch_idx = i;
-       is_mismatch_found = true;
+     if (assertion) {
+       if (i < first_mismatch_idx) {
+         first_mismatch_idx = i;
+       }
+       thread_mismatches++;
+     }
    }
+   mismatches += thread_mismatches;
  }
  return first_mismatch_idx;
}

void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
-                            const bool rowwise, double atol, double rtol, bool if_on_gpus) {
+                            const bool rowwise, double atol, double rtol, bool if_on_gpus,
+                            const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
+ size_t mismatches = 0;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
-   const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol);
-   if (i != N) {
+   const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
+   if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
      std::string direction = rowwise ? "rowwise" : "columnwise";
-     ASSERT_FALSE(true) << "Error in tensor " << name << " in "
-                        << direction << " direction." << std::endl
-                        << "Mismatch at place " << to_string(unravel(i, shape))
-                        << " (" << std::to_string(i) << "): " << t << " vs " << r;
+
+     GTEST_FAIL() << mismatches << " mismatch(es), which is more than the tolerable mismatch limit of "
+                  << tolerable_mismatches_limit << "." << std::endl
+                  << "Error in tensor " << name << " in "
+                  << direction << " direction." << std::endl
+                  << "Mismatch at place " << to_string(unravel(i, shape))
+                  << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}

void compareResults(const std::string &name, const Tensor &test, const void *ref,
-                   const bool rowwise, double atol, double rtol, bool if_on_gpus) {
+                   const bool rowwise, double atol, double rtol, bool if_on_gpus,
+                   const size_t tolerable_mismatches_limit) {
  constexpr bool sequential = false;
  if constexpr (sequential) {
-   compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus);
+   compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
  } else {
-   compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus);
+   compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
  }
}

@@ -657,25 +682,39 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
}

void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                 const size_t row_blocks, const size_t col_blocks, const size_t stride)
+                                 const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                                 size_t& mismatches_num, const size_t atol,
+                                 const double abs_tolerable_mismatches_limit,
+                                 const double rel_tolerable_mismatches_limit)
{
+ const size_t N = row_blocks * col_blocks;
+ const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
+                                                    std::floor(N * rel_tolerable_mismatches_limit));
+ mismatches_num = 0;
+ std::vector<int> mismatch_indices;
+
  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int idx = i * stride + j;
-     ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl
-                                         << "Mismatch: " << static_cast<int>(test[idx]) << " vs "
-                                         << static_cast<int>(ref[idx]) << " at index " << idx;
-   }
- }
-}
-
-void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                  const size_t N)
-{
- for (int i = 0; i < N; i++) {
-   ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
-                                   << "Mismatch: " << static_cast<int>(test[i]) << " vs "
-                                   << static_cast<int>(ref[i]) << " at index " << i;
+     const int test_val = static_cast<int>(test[idx]);
+     const int ref_val = static_cast<int>(ref[idx]);
+     const int abs_delta = std::abs(test_val - ref_val);
+
+     if (abs_delta > atol) {
+       mismatches_num++;
+       mismatch_indices.push_back(idx);
+     }
+     if (mismatches_num > tolerable_mismatches_limit) {
+       std::cout << "Error in " << name << std::endl;
+       for (const int index : mismatch_indices) {
+         std::cout << "Mismatch at (" << index << "):"
+                   << static_cast<int>(test[index]) << " vs "
+                   << static_cast<int>(ref[index]) << std::endl;
+       }
+       GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
+                    << tolerable_mismatches_limit << ".";
+     }
+   }
  }
}
...
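For reference, the rule the reworked helpers apply can be boiled down to a few lines. A minimal standalone sketch (illustrative names such as scales_match, abs_limit and rel_limit are not part of the test harness): at most min(abs_limit, floor(N * rel_limit)) entries may differ, and each surviving difference in the e8m0 byte must stay within atol.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Sketch of the tolerance rule: count mismatches first, fail only past the limit.
bool scales_match(const std::vector<uint8_t>& test, const std::vector<uint8_t>& ref,
                  int atol, double abs_limit, double rel_limit) {
  const size_t N = test.size();
  const size_t tolerable = static_cast<size_t>(
      std::min(abs_limit, std::floor(static_cast<double>(N) * rel_limit)));
  size_t mismatches = 0;
  for (size_t i = 0; i < N; ++i) {
    if (std::abs(static_cast<int>(test[i]) - static_cast<int>(ref[i])) > atol) {
      ++mismatches;
    }
  }
  return mismatches <= tolerable;  // tolerate isolated one-off differences
}

Counting before failing lets the MXFP8 tests absorb the occasional off-by-one e8m0 exponent that FP8 numerics can produce, while still catching systematic errors.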
@@ -413,7 +413,12 @@ inline fp8e8m0 float_to_e8m0(float val) {
}

inline float exp2f_rcp(fp8e8m0 biased_exp) {
- return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
+ if (biased_exp == 0) {
+   return 1.0f;
+ }
+ int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS;  // 127 - (biased_exp - 127)
+ float fp32_val = *reinterpret_cast<float*>(&int_val);
+ return fp32_val;
}

inline float identity(const float x) { return x; }
@@ -445,15 +450,18 @@ size_t last_dimension(const std::vector<size_t> &shape);
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);

void compareResults(const std::string &name, const Tensor &test, const void *ref,
-                   bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true);
+                   bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
+                   const size_t tolerable_mismatches_limit = 0);
void compareResults(const std::string &name, const float test, const float ref,
                    double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                 const size_t row_blocks, const size_t col_blocks, const size_t stride);
-void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                 const size_t N);
+                                 const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                                 size_t& mismatches_num,
+                                 const size_t scale_diff_abs_tolerance = 0,
+                                 const double abs_tolerable_mismatches_limit = 0,
+                                 const double rel_tolerable_mismatches_limit = 0);

std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
                                            const size_t block_size_rows, const size_t block_size_cols);
...
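The exp2f_rcp rewrite above builds the reciprocal scale directly as an FP32 bit pattern instead of calling exp2f: an e8m0 byte e encodes 2^(e - 127), so its reciprocal 2^(127 - e) is a power of two whose FP32 representation is just the biased exponent 254 - e shifted into the exponent field. A small self-contained check of that identity (illustrative; memcpy is used for the bit cast, and e is kept in 1..253 so the result stays a normal float):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

constexpr int kFp32MantissaBits = 23;

float e8m0_to_rcp_scale(uint8_t biased_exp) {
  if (biased_exp == 0) return 1.0f;  // same special case as the helper above
  const int32_t bits = (254 - biased_exp) << kFp32MantissaBits;
  float out;
  std::memcpy(&out, &bits, sizeof(out));  // reinterpret the bit pattern as float
  return out;
}

int main() {
  for (int e = 1; e <= 253; ++e) {
    assert(e8m0_to_rcp_scale(static_cast<uint8_t>(e)) == std::ldexp(1.0f, 127 - e));
  }
  return 0;
}

e = 254 is excluded from the check because 2^-127 falls into the FP32 subnormal range, where the shift trick returns 0.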
@@ -78,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape):
        return False

- def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
+ def assert_bitwise_scaled_tensors(
+     a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
+ ):
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
+       if not precise_comparison:
+           assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
+           return
+
        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        if a.scaling_mode.is_tensor_scaling():
@@ -94,8 +100,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
            assert_allclose(a.data, b.data)

    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
-       assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
-       assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
+       assert_bitwise_scaled_tensors(
+           a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
+       )
+       assert_bitwise_scaled_tensors(
+           a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
+       )
    else:
        pytest.fail("Unsupported input types")

@@ -481,24 +491,7 @@ class TestNorm:
            # if the input dtype is not float32
            precise_comparison = False

-       if precise_comparison:
-           assert_bitwise_scaled_tensors(output, ref_out)
-       else:
-           if isinstance(ref_out, ScaledTensor1x):
-               assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
-           elif isinstance(ref_out, ScaledTensor2x):
-               assert_allclose(
-                   output.rowwise_tensor.dequantize(),
-                   ref_out.rowwise_tensor.dequantize(),
-                   dtype=out_dtype,
-               )
-               assert_allclose(
-                   output.colwise_tensor.dequantize(),
-                   ref_out.colwise_tensor.dequantize(),
-                   dtype=out_dtype,
-               )
-           else:
-               pytest.fail("Unsupported output type")
+       assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
@@ -768,12 +761,24 @@ class TestFusedQuantize:
        )(dz, x)

        if is_casted_output:
-           assert_bitwise_scaled_tensors(te_output, jax_output)
+           # TE kernels cast the intermediate results to the input dtype, which reduces precision compared to the JAX implementation
+           precise_comparison = not (
+               in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
+           )
+           assert_bitwise_scaled_tensors(
+               te_output, jax_output, precise_comparison=precise_comparison
+           )
        else:
            assert_allclose(te_output, jax_output)

        if is_dbias:
-           assert_allclose(te_dbias, jax_dbias)
+           # TE kernels cast the intermediate results to the input dtype, which reduces precision compared to the JAX implementation; for dbias this typically only affects bfloat16
+           precise_comparison = not (
+               in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
+           )
+           assert_allclose(
+               te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
+           )

@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
...
@@ -192,6 +192,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
  set_source_files_properties(activation/gelu.cu
                              activation/relu.cu
                              activation/swiglu.cu
+                             util/cast.cu
                              PROPERTIES
                              COMPILE_OPTIONS "--use_fast_math")
endif()
...
@@ -162,10 +162,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
  void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
                                           (offset_elems * type_num_bits) / 8);
- NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
+ NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
             "Tensor data pointer must be 16B aligned");
- const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
+ const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
             "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
...
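To make the shape check above concrete: a TMA-copied row must cover a whole number of 16-byte (TMA_GMEM_ALIGNMENT) units, so the required element multiple is (16 * 8) / bits_per_element. A tiny host-side illustration (constant names are local to the example):

#include <cstdio>

constexpr int kTmaGmemAlignmentBytes = 16;

// Smallest number of elements whose total size is a multiple of 16 bytes.
constexpr int needed_multiple(int type_num_bits) {
  return (kTmaGmemAlignmentBytes * 8) / type_num_bits;
}

int main() {
  std::printf("fp8  (8-bit):  row length must be a multiple of %d\n", needed_multiple(8));   // 16
  std::printf("bf16 (16-bit): row length must be a multiple of %d\n", needed_multiple(16));  // 8
  std::printf("fp32 (32-bit): row length must be a multiple of %d\n", needed_multiple(32));  // 4
  return 0;
}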
@@ -668,7 +668,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;

// Alignment requirements for the Tensor Memory Accelerator (TMA)
- constexpr int TMA_gmem_alignment = 16;       // global memory address alignment
+ constexpr size_t TMA_GMEM_ALIGNMENT = 16;    // global memory address alignment
+ constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // shared memory address alignment

inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
  return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
...
@@ -84,8 +84,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
  // const int thread_offset_X_colwise = tid_colwise_X;

  // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned
- __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
- __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
+ __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
+ __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];

  constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM;
  constexpr int transaction_size = shmem_buff_size;
@@ -166,7 +166,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
      const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X;
      const e8m0_t biased_exponent = scales_ptr[scale_idx];
-     const float block_scale = exp2f(static_cast<float>(biased_exponent) - FP32_EXPONENT_BIAS);
+     const float block_scale = ptx::exp2f(biased_exponent);

      if constexpr (USE_ROWWISE_SCALING) {
        Vec<IType, ELEMS_PER_THREAD> in;
...
@@ -104,6 +104,53 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)

constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1
: __int_as_float((254 - biased_exp)
<< FP32_MANTISSA_BITS); // 127 - (biased_exp - 127)
}
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
}
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out;
asm volatile(
"{\n"
"cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
"}"
: "=h"(out)
: "f"(val));
return *reinterpret_cast<e8m0_t *>(&out);
#else
// TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax.
if (isnan(val)) {
return 0xFF;
}
if (isinf(val)) {
return 0xFE;
}
if (val == 0.0f) {
return 0x00;
}
uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent;
}
return exponent;
#endif
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor

@@ -169,6 +216,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
  asm volatile("fence.proxy.async.shared::cta;");
}

template <typename T>
struct alignas(2 * sizeof(T)) FPx2 {
T x;
T y;
};
using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;
static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);
// SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
"mul.f32x2 val_pair, %1, %2; \n\t"
"mov.b64 {val2,val1}, val_pair; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
"mul.f32x2 val_pair, %1, %2; \n\t"
"mov.b64 {val2,val1}, val_pair; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_bf16; \n\t"
".reg.b16 val2_bf16; \n\t"
"mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
"cvt.f32.bf16 val1, val1_bf16; \n\t"
"cvt.f32.bf16 val2, val2_bf16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_bf16; \n\t"
".reg.b16 val2_bf16; \n\t"
"mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
"cvt.f32.bf16 val1, val1_bf16; \n\t"
"cvt.f32.bf16 val2, val2_bf16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_fp16; \n\t"
".reg.b16 val2_fp16; \n\t"
"mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
"cvt.f32.f16 val1, val1_fp16; \n\t"
"cvt.f32.f16 val2, val2_fp16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_fp16; \n\t"
".reg.b16 val2_fp16; \n\t"
"mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
"cvt.f32.f16 val1, val1_fp16; \n\t"
"cvt.f32.f16 val2, val2_fp16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2)));
}
__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2)));
}
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

}  // namespace ptx
...
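The helpers added above (float_to_e8m0, exp2f_rcp, the packed FPx2 type and the mul_cvt_2x / abs_max_2x instructions) are the building blocks of the per-block MXFP8 scaling recipe: find the block amax, round the amax-to-FP8-max ratio up to a power of two, and multiply by the reciprocal scale while casting to FP8. A rough host-side sketch of that recipe, with the FP8 cast emulated by clamping and the zero/NaN corner cases simplified:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

constexpr float kE4M3Max = 448.0f;  // max normal value of FP8 E4M3

// Round block_amax / 448 up to the next power of two and return its e8m0 biased exponent.
uint8_t compute_e8m0_scale(float block_amax) {
  if (block_amax == 0.0f) return 127;  // arbitrary choice here; the real helper has dedicated zero/NaN handling
  const int unbiased = static_cast<int>(std::ceil(std::log2(block_amax / kE4M3Max)));
  return static_cast<uint8_t>(std::clamp(unbiased + 127, 0, 254));
}

// Quantize one 32-element block: amax -> e8m0 scale -> multiply by reciprocal -> (emulated) FP8 cast.
void quantize_block(const float* in, float* out_emulated_fp8, uint8_t* scale, size_t n = 32) {
  float amax = 0.0f;
  for (size_t i = 0; i < n; ++i) amax = std::max(amax, std::fabs(in[i]));
  *scale = compute_e8m0_scale(amax);
  const float rcp = std::ldexp(1.0f, 127 - static_cast<int>(*scale));  // 2^-(scale - 127)
  for (size_t i = 0; i < n; ++i) {
    out_emulated_fp8[i] = std::clamp(in[i] * rcp, -kE4M3Max, kE4M3Max);  // stand-in for the FP8 cast
  }
}

On Blackwell the same steps run two elements at a time: cvt.rp.satfinite.ue8m0x2.f32 produces the scale, and mul.f32x2 followed by cvt.rn.satfinite.e4m3x2/e5m2x2 (wrapped in mul_cvt_2x above) fuses the reciprocal multiply with the packed FP8 conversion.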
@@ -905,10 +905,7 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using e8m0_t = uint8_t;

- constexpr uint32_t FP32_MANTISSA_BITS = 23;
- constexpr uint32_t FP32_EXPONENT_BIAS = 127;
-
- enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 };
+ enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };

template <typename T>
struct Numeric_Traits;
@@ -934,44 +931,6 @@ struct Quantized_Limits {
  static constexpr float emax_rcp = 1.0 / emax;
};

- __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
-   // TODO: nan/inf needs to be set for any value
-   // of nan/inf in input not just amax.
-   if (isnan(val)) {
-     return 0xFF;
-   }
-   if (isinf(val)) {
-     return 0xFE;
-   }
- #if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
-      (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
-   uint16_t out;
-   asm volatile(
-       "{\n"
-       "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
-       "}"
-       : "=h"(out)
-       : "f"(val));
-   return *reinterpret_cast<e8m0_t *>(&out);
- #else
-   if (val == 0.0f) {
-     return 0x00;
-   }
-   uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
-   e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
-   uint32_t mantissa = val_u32 & 0x7FFFFF;
-   // Round up exponent and deal with satfinite.
-   if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
-     ++exponent;
-   }
-   return exponent;
- #endif
- }
-
- __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
-   return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
- }

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
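float_to_e8m0 (now a ptx helper) rounds the block amax up to the next power of two. A tiny host-side reference of that round-up behavior, mirroring the non-Blackwell branch shown above, with the NaN/Inf/zero special cases omitted:

#include <cassert>
#include <cstdint>
#include <cstring>

uint8_t float_to_e8m0_ref(float val) {  // simplified: positive, finite, non-zero inputs only
  uint32_t bits;
  std::memcpy(&bits, &val, sizeof(bits));
  uint8_t exponent = static_cast<uint8_t>(bits >> 23);
  const uint32_t mantissa = bits & 0x7FFFFF;
  // Round the exponent up unless the value is already a power of two; cap at 0xFE (satfinite).
  if (mantissa > 0 && exponent != 0xFE && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
}

int main() {
  assert(float_to_e8m0_ref(1.0f) == 127);  // exactly 2^0
  assert(float_to_e8m0_ref(1.5f) == 128);  // rounded up to 2^1
  assert(float_to_e8m0_ref(4.0f) == 129);  // exactly 2^2
  assert(float_to_e8m0_ref(5.0f) == 130);  // rounded up to 2^3
  return 0;
}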