Unverified Commit 77f9d73c authored by yjk21's avatar yjk21 Committed by GitHub
Browse files

[FastLayerNorm] Support hidden dim of 14336 (#1368)

parent f9305e75
...@@ -166,6 +166,12 @@ REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); ...@@ -166,6 +166,12 @@ REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
......
...@@ -154,6 +154,12 @@ REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); ...@@ -154,6 +154,12 @@ REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
......
...@@ -216,6 +216,7 @@ class TestFastLayerNorm(unittest.TestCase): ...@@ -216,6 +216,7 @@ class TestFastLayerNorm(unittest.TestCase):
10240, 10240,
12288, 12288,
12800, 12800,
14336,
15360, 15360,
16384, 16384,
18432, 18432,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment