Unverified Commit 0e80c847 authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[Common] Split cast/gated kernels by scaling mode (#2248)



* Separated gated and dequantize kernels
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Separated quantize, dequantize and gated functions
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed lint issues
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed persistent lint issues
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added missing compute capability 10.0 check for Quantize FP8 TMA kernels
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed the issue which was added again by autofix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed files description. Completely removed non-identity activations from the NVFP4 transpose test suite
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Removed unsupported template arguments in NVFP4 quantize
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed undefined symbol error
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed condition
Signed-off-by: default avatarOleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Fixed CUDA version check
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed arch conditions order
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Clean up
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Small fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Small fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixes per the PR review
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Split quantize helper into two (FWD and BWD) functions
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Moved activation functions from cast.cu. Removed cast.cu from the fast-math compilation list
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Enabled fast math for activations by default
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Disabled fast math for activations by default
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: default avatarOleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 490a5f41
This diff is collapsed.
......@@ -36,6 +36,8 @@ __device__ inline OType sigmoid(const IType val, const Empty&) {
return 1.f / (1.f + expf(-cval));
}
__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
template <typename OType, typename IType>
__device__ inline OType dsigmoid(const IType val, const Empty& e) {
const float cval = val;
......
......@@ -449,13 +449,12 @@ static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);
#if CUDA_VERSION >= 12080
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2);
#endif // CUDA_VERSION >= 12080
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
......@@ -464,7 +463,6 @@ static_assert(sizeof(fp4e2m1x4) == 2);
// from input b is stored in the lower 4 bits of d.
// SIMD like "Fused" cast + multiplication (x4)
#if CUDA_VERSION >= 12080
template <typename Tx2>
__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
const float scale) {
......@@ -474,7 +472,192 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons
const float x3 = static_cast<float>(in23.y) * scale;
out = fp4e2m1x4(make_float4(x0, x1, x2, x3));
}
#endif // CUDA_VERSION >= 12080
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
if constexpr (is_blackwell) {
// NOTE: rbits unused for rn.
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
} else {
return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
}
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const float2 scale,
const uint32_t rbits) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
if constexpr (is_blackwell) {
// NOTE: rbits unused for rn.
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
} else {
return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
}
}
#endif // FP4_TYPE_SUPPORTED
// SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment