"docs/vscode:/vscode.git/clone" did not exist on "08b2812ce4d615a494e1a1c29a889dd6d15b2b48"
Commit 98395085 authored by rocking's avatar rocking
Browse files

Add kMThreadPerBlock to template parameter

parent 03247367
......@@ -5,28 +5,30 @@
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
NRepeat,
NThread,
kMThreadPerBlock,
kNThreadPerBlock,
VectorAccessSize,
false,
false,
kTwoPass>;
// Disable all 8 x fp16 vector read/write instances because they hit a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
......@@ -5,12 +5,14 @@
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
NRepeat,
NThread,
kMThreadPerBlock,
kNThreadPerBlock,
VectorAccessSize,
true,
false,
......@@ -24,19 +26,19 @@ using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, true>>(const S&, A);
......@@ -5,28 +5,30 @@
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
NRepeat,
NThread,
kMThreadPerBlock,
kNThreadPerBlock,
VectorAccessSize,
false,
false,
kTwoPass>;
// Disable all 8 x fp16 vector read/write instances because they hit a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
......@@ -5,38 +5,40 @@
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
NRepeat,
NThread,
kMThreadPerBlock,
kNThreadPerBlock,
VectorAccessSize,
true,
false,
kTwoPass>;
// Disable all 8 x fp16 vector read/write instances because they hit a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, true>>(const S&, A);
......@@ -52,7 +52,8 @@ struct layernorm2d_fwd_args
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename DataType_,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t VectorAccessSize,
bool kPadN_,
bool kSaveMeanInvStd_,
......@@ -62,14 +63,17 @@ struct layernorm2d_fwd_traits_
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t MRepeat = 1;
static_assert(NThread <= 64, "We only support intra-wave reduction");
static constexpr ck_tile::index_t WaveNum = NThread / 16;
static_assert(kNThreadPerBlock <= 64, "We only support intra-wave reduction");
static constexpr ck_tile::index_t kNWarpPerBlock = 1;
static constexpr ck_tile::index_t kMWarpPerBlock =
kMThreadPerBlock * kNThreadPerBlock / warpSize;
// kNThreadPerBlock / 16;
using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
using warp_tile =
ck_tile::sequence<MRepeat * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
using block_tile =
ck_tile::sequence<MRepeat * WaveNum * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
using warp_tile = ck_tile::sequence<MRepeat * warpSize / kNThreadPerBlock,
NRepeat * kNThreadPerBlock * VectorAccessSize>;
using block_tile = ck_tile::sequence<kMWarpPerBlock * MRepeat * warpSize / kNThreadPerBlock,
NRepeat * kNThreadPerBlock * VectorAccessSize>;
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
......
......@@ -6,12 +6,19 @@
template <typename DataType,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t VectorAccessSize,
bool kPadN,
bool kTwoPass = false>
using trait_ =
layernorm2d_fwd_traits_<DataType, NRepeat, NThread, VectorAccessSize, kPadN, false, kTwoPass>;
using trait_ = layernorm2d_fwd_traits_<DataType,
NRepeat,
kMThreadPerBlock,
kNThreadPerBlock,
VectorAccessSize,
kPadN,
false,
kTwoPass>;
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
......@@ -24,70 +31,75 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
{
if(a.N <= 128)
{
return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, true>>(s, a);
return a.N == 128
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 4, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, true>>(s, a);
return a.N == 256
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, true>>(s, a);
return a.N == 512
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, true>>(s, a);
}
else
{
return a.N % 2048 == 0
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false, true>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, false, true>>(s,
a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, true, true>>(s,
a);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 256)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 512)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 1024)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 2048)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 4, 64, 2, true>>(s, a);
}
else
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 4, 64, 2, true, true>>(s, a);
}
}
else
{
return a.N <= 2048
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 4, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 4, 64, 1, true, true>>(s, a);
}
}
else if(t.data_type.compare("bf16") == 0)
......@@ -96,70 +108,75 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
{
if(a.N <= 128)
{
return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, true>>(s, a);
return a.N == 128
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 4, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, true>>(s, a);
return a.N == 256
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, true>>(s, a);
return a.N == 512
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, true>>(s, a);
}
else
{
return a.N % 2048 == 0
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false, true>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, false, true>>(s,
a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, true, true>>(s,
a);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 256)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 512)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 1024)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 2048)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 4, 64, 2, true>>(s, a);
}
else
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true, true>>(s, a);
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 4, 64, 2, true, true>>(s, a);
}
}
else
{
return a.N <= 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, true>>(s, a);
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 4, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 4, 64, 1, true, true>>(s, a);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment