Commit a30bd5da authored by carlushuang's avatar carlushuang
Browse files

update instance for 2048, 1024 case

parent 24bebf15
...@@ -50,11 +50,11 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/, ...@@ -50,11 +50,11 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
} }
else if(a.n <= 1024) { else if(a.n <= 1024) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 128, 8, true, false, false>>(s, a); r = layernorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 128, 4, true, false, false>>(s, a); r = layernorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 128, 2, true, false, false>>(s, a); r = layernorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a); r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
} }
...@@ -62,7 +62,7 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/, ...@@ -62,7 +62,7 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a); r = layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 128, 8, true, false, false>>(s, a); r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a); r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
else else
......
...@@ -15,8 +15,8 @@ template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true ...@@ -15,8 +15,8 @@ template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif #endif
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 8, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 128, 4, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 128, 2, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd mv 2p // rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 128, 8, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
......
...@@ -15,8 +15,8 @@ template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true ...@@ -15,8 +15,8 @@ template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif #endif
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 8, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 128, 4, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 128, 2, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd mv 2p // rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 128, 8, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A); template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment