Commit 98395085 authored by rocking

Add kMThreadPerBlock to template parameter

parent 03247367
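The change below splits what was a single NThread template parameter into kMThreadPerBlock (thread rows per workgroup, instantiated as 4 everywhere in this commit) and kNThreadPerBlock (threads cooperating along the normalized dimension). A minimal standalone sketch of the resulting arithmetic — the names mirror the diff, but the sketch itself is not from the commit, and warpSize = 64 is an assumption (AMD wavefront size):

// Sketch only (not from the commit): how the two thread-count
// parameters combine into a workgroup size.
#include <cstdio>

int main()
{
    constexpr int kMThreadPerBlock = 4;  // new parameter; 4 in every instance below
    constexpr int kNThreadPerBlock = 64; // formerly the single "NThread" parameter
    constexpr int warpSize         = 64; // assumed AMD wavefront size

    // Mirrors kMWarpPerBlock = kMThreadPerBlock * kNThreadPerBlock / warpSize
    // from the traits change further down.
    constexpr int threadsPerBlock = kMThreadPerBlock * kNThreadPerBlock; // 256
    constexpr int kMWarpPerBlock  = threadsPerBlock / warpSize;          // 4

    std::printf("threads/block = %d, warps/block = %d\n", threadsPerBlock, kMWarpPerBlock);
    return 0;
}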
@@ -5,28 +5,30 @@
 #include "layernorm2d_fwd_instance_common.hpp"
 
 template <ck_tile::index_t NRepeat,
-          ck_tile::index_t NThread,
+          ck_tile::index_t kMThreadPerBlock,
+          ck_tile::index_t kNThreadPerBlock,
           ck_tile::index_t VectorAccessSize,
           bool kTwoPass>
 using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
                                   NRepeat,
-                                  NThread,
+                                  kMThreadPerBlock,
+                                  kNThreadPerBlock,
                                   VectorAccessSize,
                                   false,
                                   false,
                                   kTwoPass>;
 // Disable all vector 8fp16 read/write instances, as they have a performance issue with the compiler
-// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
@@ -5,12 +5,14 @@
 #include "layernorm2d_fwd_instance_common.hpp"
 
 template <ck_tile::index_t NRepeat,
-          ck_tile::index_t NThread,
+          ck_tile::index_t kMThreadPerBlock,
+          ck_tile::index_t kNThreadPerBlock,
           ck_tile::index_t VectorAccessSize,
           bool kTwoPass>
 using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
                                   NRepeat,
-                                  NThread,
+                                  kMThreadPerBlock,
+                                  kNThreadPerBlock,
                                   VectorAccessSize,
                                   true,
                                   false,
@@ -24,19 +26,19 @@ using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
 // template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
 // template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<16, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<16, 4, 64, 2, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
-template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
+template float layernorm2d_fwd_<t<32, 4, 64, 1, false>>(const S&, A);
+template float layernorm2d_fwd_<t<32, 4, 64, 1, true>>(const S&, A);
@@ -5,28 +5,30 @@
 #include "layernorm2d_fwd_instance_common.hpp"
 
 template <ck_tile::index_t NRepeat,
-          ck_tile::index_t NThread,
+          ck_tile::index_t kMThreadPerBlock,
+          ck_tile::index_t kNThreadPerBlock,
           ck_tile::index_t VectorAccessSize,
           bool kTwoPass>
 using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
                                   NRepeat,
-                                  NThread,
+                                  kMThreadPerBlock,
+                                  kNThreadPerBlock,
                                   VectorAccessSize,
                                   false,
                                   false,
                                   kTwoPass>;
 // Disable all vector 8fp16 read/write instances, as they have a performance issue with the compiler
-// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
@@ -5,38 +5,40 @@
 #include "layernorm2d_fwd_instance_common.hpp"
 
 template <ck_tile::index_t NRepeat,
-          ck_tile::index_t NThread,
+          ck_tile::index_t kMThreadPerBlock,
+          ck_tile::index_t kNThreadPerBlock,
           ck_tile::index_t VectorAccessSize,
           bool kTwoPass>
 using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
                                   NRepeat,
-                                  NThread,
+                                  kMThreadPerBlock,
+                                  kNThreadPerBlock,
                                   VectorAccessSize,
                                   true,
                                   false,
                                   kTwoPass>;
 // Disable all vector 8fp16 read/write instances, as they have a performance issue with the compiler
-// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
-// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
-template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<16, 4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<16, 4, 64, 2, true>>(const S&, A);
 
-template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
-template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
+template float layernorm2d_fwd_<t<32, 4, 64, 1, false>>(const S&, A);
+template float layernorm2d_fwd_<t<32, 4, 64, 1, true>>(const S&, A);
@@ -52,7 +52,8 @@ struct layernorm2d_fwd_args
 // this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
 template <typename DataType_,
           ck_tile::index_t NRepeat,
-          ck_tile::index_t NThread,
+          ck_tile::index_t kMThreadPerBlock,
+          ck_tile::index_t kNThreadPerBlock,
           ck_tile::index_t VectorAccessSize,
           bool kPadN_,
           bool kSaveMeanInvStd_,
@@ -62,14 +63,17 @@ struct layernorm2d_fwd_traits_
 {
     using DataType = ck_tile::remove_cvref_t<DataType_>;
 
     static constexpr ck_tile::index_t MRepeat = 1;
-    static_assert(NThread <= 64, "We only support intra-wave reduction");
-    static constexpr ck_tile::index_t WaveNum = NThread / 16;
+    static_assert(kNThreadPerBlock <= 64, "We only support intra-wave reduction");
+    static constexpr ck_tile::index_t kNWarpPerBlock = 1;
+    static constexpr ck_tile::index_t kMWarpPerBlock =
+        kMThreadPerBlock * kNThreadPerBlock / warpSize;
+    // kNThreadPerBlock / 16;
 
     using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
-    using warp_tile =
-        ck_tile::sequence<MRepeat * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
-    using block_tile =
-        ck_tile::sequence<MRepeat * WaveNum * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
+    using warp_tile  = ck_tile::sequence<MRepeat * warpSize / kNThreadPerBlock,
+                                         NRepeat * kNThreadPerBlock * VectorAccessSize>;
+    using block_tile = ck_tile::sequence<kMWarpPerBlock * MRepeat * warpSize / kNThreadPerBlock,
+                                         NRepeat * kNThreadPerBlock * VectorAccessSize>;
 
     using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
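The static_assert keeps the N-dimension reduction inside a single 64-lane wavefront. To sanity-check the traits arithmetic, a hedged verification sketch for the t<8, 4, 64, 4, false> instance (NRepeat = 8, kMThreadPerBlock = 4, kNThreadPerBlock = 64, VectorAccessSize = 4), assuming warpSize == 64 and C++17; the constant names mirror the diff, but the sketch itself is not part of the commit:

// Evaluates the tile shapes the traits above would produce.
constexpr int MRepeat          = 1;  // fixed in the traits
constexpr int NRepeat          = 8;
constexpr int kMThreadPerBlock = 4;
constexpr int kNThreadPerBlock = 64;
constexpr int VectorAccessSize = 4;
constexpr int warpSize         = 64; // assumed AMD wavefront size

constexpr int kMWarpPerBlock = kMThreadPerBlock * kNThreadPerBlock / warpSize; // = 4

// warp_tile  = <MRepeat * warpSize / kNThreadPerBlock,
//               NRepeat * kNThreadPerBlock * VectorAccessSize> = <1, 2048>
static_assert(MRepeat * warpSize / kNThreadPerBlock == 1);
static_assert(NRepeat * kNThreadPerBlock * VectorAccessSize == 2048);

// block_tile = <kMWarpPerBlock * MRepeat * warpSize / kNThreadPerBlock, 2048> = <4, 2048>,
// i.e. each workgroup now covers four rows per iteration instead of one,
// which is the point of the new kMThreadPerBlock parameter.
static_assert(kMWarpPerBlock * MRepeat * warpSize / kNThreadPerBlock == 4);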
@@ -6,12 +6,19 @@
 
 template <typename DataType,
           ck_tile::index_t NRepeat,
-          ck_tile::index_t NThread,
+          ck_tile::index_t kMThreadPerBlock,
+          ck_tile::index_t kNThreadPerBlock,
           ck_tile::index_t VectorAccessSize,
           bool kPadN,
           bool kTwoPass = false>
-using trait_ =
-    layernorm2d_fwd_traits_<DataType, NRepeat, NThread, VectorAccessSize, kPadN, false, kTwoPass>;
+using trait_ = layernorm2d_fwd_traits_<DataType,
+                                       NRepeat,
+                                       kMThreadPerBlock,
+                                       kNThreadPerBlock,
+                                       VectorAccessSize,
+                                       kPadN,
+                                       false,
+                                       kTwoPass>;
 
 float layernorm2d_fwd(layernorm2d_fwd_traits t,
                       layernorm2d_fwd_args a,
@@ -24,70 +31,75 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
         {
             if(a.N <= 128)
             {
-                return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, false>>(s, a)
-                                  : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, true>>(s, a);
+                return a.N == 128
+                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 4, true>>(s, a);
             }
             else if(a.N <= 256)
             {
-                return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, false>>(s, a)
-                                  : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, true>>(s, a);
+                return a.N == 256
+                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 4, true>>(s, a);
            }
            else if(a.N <= 512)
            {
-                return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, false>>(s, a)
-                                  : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, true>>(s, a);
+                return a.N == 512
+                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 4, true>>(s, a);
            }
            else if(a.N <= 1024)
            {
                return a.N == 1024
-                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, false>>(s, a)
-                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, true>>(s, a);
+                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 4, true>>(s, a);
            }
            else if(a.N <= 2048)
            {
                return a.N == 2048
-                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false>>(s, a)
-                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true>>(s, a);
+                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, true>>(s, a);
            }
            else
            {
                return a.N % 2048 == 0
-                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false, true>>(s, a)
-                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true, true>>(s, a);
+                           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, false, true>>(s,
+                                                                                                 a)
+                           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, true, true>>(s,
+                                                                                                a);
            }
        }
        else if(a.N % 2 == 0)
        {
            if(a.N <= 128)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 256)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 512)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 1024)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 2048)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 4, 64, 2, true>>(s, a);
            }
            else
            {
-                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 4, 64, 2, true, true>>(s, a);
            }
        }
        else
        {
            return a.N <= 2048
-                       ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, false>>(s, a)
-                       : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, true>>(s, a);
+                       ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 4, 64, 1, true, false>>(s, a)
+                       : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 4, 64, 1, true, true>>(s, a);
        }
    }
    else if(t.data_type.compare("bf16") == 0)
@@ -96,70 +108,75 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
        {
            if(a.N <= 128)
            {
-                return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, false>>(s, a)
-                                  : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, true>>(s, a);
+                return a.N == 128
+                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 4, true>>(s, a);
            }
            else if(a.N <= 256)
            {
-                return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, false>>(s, a)
-                                  : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, true>>(s, a);
+                return a.N == 256
+                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 4, true>>(s, a);
            }
            else if(a.N <= 512)
            {
-                return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, false>>(s, a)
-                                  : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, true>>(s, a);
+                return a.N == 512
+                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 4, true>>(s, a);
            }
            else if(a.N <= 1024)
            {
                return a.N == 1024
-                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, false>>(s, a)
-                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, true>>(s, a);
+                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 4, true>>(s, a);
            }
            else if(a.N <= 2048)
            {
                return a.N == 2048
-                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false>>(s, a)
-                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true>>(s, a);
+                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, false>>(s, a)
+                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, true>>(s, a);
            }
            else
            {
                return a.N % 2048 == 0
-                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false, true>>(s, a)
-                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true, true>>(s, a);
+                           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, false, true>>(s,
+                                                                                                 a)
+                           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, true, true>>(s,
+                                                                                                a);
            }
        }
        else if(a.N % 2 == 0)
        {
            if(a.N <= 128)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 256)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 512)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 1024)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 2, true>>(s, a);
            }
            else if(a.N <= 2048)
            {
-                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 4, 64, 2, true>>(s, a);
            }
            else
            {
-                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true, true>>(s, a);
+                return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 4, 64, 2, true, true>>(s, a);
            }
        }
        else
        {
            return a.N <= 2048
-                       ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, false>>(s, a)
-                       : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, true>>(s, a);
+                       ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 4, 64, 1, true, false>>(s, a)
+                       : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 4, 64, 1, true, true>>(s, a);
        }
    }
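The dispatch above picks an instance from the row length a.N. Below is a distilled summary as a standalone constexpr function — a sketch of my reading of the branches, not code from the commit; the hypothetical names Pick and pick_fp16 are introduced only for illustration, the enclosing condition of the first group is not visible in the hunks and is assumed to be a.N % 4 == 0 from the vector width, and C++17 is assumed for the message-less static_assert:

// Which <NRepeat, kNThreadPerBlock, VectorAccessSize, kPadN, kTwoPass>
// each row length N selects on the fp16 path; kMThreadPerBlock is 4 in
// every case. The bf16 path is identical.
struct Pick { int nrepeat, nthread, vec; bool pad, two_pass; };

constexpr Pick pick_fp16(int n)
{
    if(n % 4 == 0) // assumed enclosing condition: 4-element vector access
    {
        if(n <= 128)  return {1, 32, 4, n != 128, false};
        if(n <= 256)  return {1, 64, 4, n != 256, false};
        if(n <= 512)  return {2, 64, 4, n != 512, false};
        if(n <= 1024) return {4, 64, 4, n != 1024, false};
        if(n <= 2048) return {8, 64, 4, n != 2048, false};
        return {8, 64, 4, n % 2048 != 0, true}; // two-pass kernel for long rows
    }
    if(n % 2 == 0) // 2-element vector access, always padded
    {
        if(n <= 128)  return {1, 64, 2, true, false};
        if(n <= 256)  return {2, 64, 2, true, false};
        if(n <= 512)  return {4, 64, 2, true, false};
        if(n <= 1024) return {8, 64, 2, true, false};
        if(n <= 2048) return {16, 64, 2, true, false};
        return {16, 64, 2, true, true};
    }
    return {32, 64, 1, true, n > 2048}; // odd N: scalar access
}

// Spot checks against the branches in the diff.
static_assert(pick_fp16(1024).vec == 4 && pick_fp16(1024).nrepeat == 4);
static_assert(pick_fp16(1000).vec == 4 && pick_fp16(1000).pad); // 1000 % 4 == 0
static_assert(pick_fp16(1026).vec == 2);                        // even, not divisible by 4
static_assert(pick_fp16(1025).vec == 1 && !pick_fp16(1025).two_pass);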