Unverified Commit 27cb9950 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Compile tuned RMSNorm kernels for hidden size 8192 (#747)


Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent ce328fac
...@@ -201,6 +201,10 @@ REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); ...@@ -201,6 +201,10 @@ REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Create rmsnorm general launch function and register. Macro signature: // Create rmsnorm general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
......
...@@ -170,6 +170,13 @@ REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); ...@@ -170,6 +170,13 @@ REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// Create rmsnorm general launch function and register. Macro signature: // Create rmsnorm general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment