Unverified Commit 9533a172 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into codegen-enable-hiprtc

parents c2cf0733 50ee4267
......@@ -165,6 +165,8 @@ struct fmha_fwd_splitkv_args
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
// nullptr.
const void* cache_batch_idx;
......@@ -173,9 +175,21 @@ struct fmha_fwd_splitkv_args
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// kvcache mode (use same kernel as batch mode):
// or kargs.seqlen_k_ptr[b]
//
// batch mode (kvcache):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k_ptr[b]
// group mode (kvcache):
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
//
// when is_gappy=true:
// seqlen_k = kargs.seqlen_k_ptr[b]
// seqstart_k_ptr[b] now store local offset of each batch
//
// when is_gappy=false:
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
......@@ -251,7 +265,7 @@ struct fmha_fwd_appendkv_args
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx;
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
......@@ -278,7 +292,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(args.q_ptr,
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
......@@ -317,7 +331,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargs(args.q_ptr,
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
......@@ -389,6 +403,10 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.is_gappy,
args.scale_s,
args.scale_p,
args.stride_q,
......
......@@ -47,6 +47,9 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
assert output_file is not None
file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl)
......
......@@ -29,14 +29,14 @@ while getopts ":sa" opt; do
done
run_fp16_bf16_tests() {
local NUM_SPLITS=(1)
local PAGE_BLOCK_SIZE=(0)
local CACHE_BATCH_IDX=(0)
local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE="0"
local CACHE_BATCH_IDX="0"
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS+=(2 3)
PAGE_BLOCK_SIZE+=(128)
CACHE_BATCH_IDX+=(1)
NUM_SPLITS="$NUM_SPLITS 2 3"
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
fi
for prec in "fp16" "bf16" ; do
......@@ -47,9 +47,9 @@ run_fp16_bf16_tests() {
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in "${NUM_SPLITS[@]}" ; do
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
for num_splits in $NUM_SPLITS ; do
for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
......
......@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
std::string k_val,
std::string k_pad_val,
ck_tile::index_t seqlen_k_min = 0,
bool use_kvcache = false,
bool need_append_kvcache = false,
std::optional<unsigned> seed = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
......@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && use_kvcache)
if(1 < batch && need_append_kvcache)
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints(std::next(seqlen_ks.begin()),
......
......@@ -69,7 +69,7 @@ args:
```
## limitations
Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, N>8192 case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet.
Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose.
```
# some case
......
......@@ -57,6 +57,7 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
......@@ -118,6 +119,7 @@ struct layernorm2d_fwd_traits_
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
......@@ -134,6 +136,7 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
......@@ -148,6 +151,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kFastFDiv_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
......@@ -179,6 +183,7 @@ float layernorm2d_fwd_(const S& s, A a)
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kFastFDiv,
Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
......@@ -202,8 +207,9 @@ float layernorm2d_fwd_(const S& s, A a)
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
......@@ -268,7 +274,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
#include "layernorm2d_fwd_api_common.hpp"
// clang-format off
// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep
{F_instance_def}
// clang-format on
......@@ -355,6 +361,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_Vector_N : int
F_kPadN : bool
F_kSaveMeanInvStd_ : bool
F_kFastFDiv_ : bool
F_kTwoPass_ : bool
F_kFusedAdd : int
F_kFusedQuant : int
......@@ -362,7 +369,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@property
def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_
......@@ -482,52 +489,55 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
# rm rn tm tn vn pd mv fdiv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]}
total_blob = list()
for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key]
......@@ -558,7 +568,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
w_p = Path(self.working_path)
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
blobs = self.get_blobs()
with list_p.open('a') as list_f:
with list_p.open('w') as list_f:
# api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
......
......@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not")
......@@ -56,9 +59,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
if(xr_stride < 0)
xr_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
if(yr_stride < 0)
yr_stride = n;
float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
......@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false;
}
assert(stride >= n);
assert(x_stride >= n);
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
......@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
......@@ -127,9 +139,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
......@@ -161,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << std::flush;
layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
......@@ -181,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
epsilon,
m,
n,
stride};
x_stride, // x row_stride
xr_stride, // x residule row stride
y_stride, // y row stride
yr_stride}; // y residule row stride
float ave_time = layernorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
......@@ -212,7 +230,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_host.mData.cend(),
x_residual_host.mData.cbegin(),
x_host.mData.begin(),
std::plus<XDataType>{});
[](auto x_, auto r_) {
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
ck_tile::type_convert<ComputeDataType>(r_);
return ck_tile::type_convert<XDataType>(o_);
});
}
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
......@@ -280,32 +302,35 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data());
ck_tile::HostTensor<YResidualDataType> sy_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
if(fused_add == 1)
{
y_residual_buf.FromDevice(sy_host_dev.data());
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
auto [rtol, atol] = get_elimit<InDataType>();
if(stride == n)
if(x_stride == n)
{
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
pass &= ck_tile::check_err(y_residual_host_dev,
x_host,
std::string("ADD Error: Incorrect results!"),
rtol,
atol);
}
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
y_host_dev.begin() + i_r * stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
y_host_ref.begin() + i_r * stride + n);
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
y_host_dev.begin() + i_r * y_stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
y_host_ref.begin() + i_r * y_stride + n);
pass &= ck_tile::check_err(y_host_dev_row,
y_host_ref_row,
std::string("OUT[") + std::to_string(i_r) +
......@@ -314,12 +339,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
atol);
if(fused_add == 1)
{
std::vector<YResidualDataType> sy_host_dev_row(
sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n);
std::vector<YResidualDataType> sy_host_ref_row(
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
pass &= ck_tile::check_err(sy_host_dev_row,
sy_host_ref_row,
std::vector<YResidualDataType> y_residual_host_dev_row(
y_residual_host_dev.begin() + i_r * yr_stride,
y_residual_host_dev.begin() + i_r * yr_stride + n);
std::vector<YResidualDataType> y_residual_host_ref_row(
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
pass &= ck_tile::check_err(y_residual_host_dev_row,
y_residual_host_ref_row,
std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
......
# run from top of ck folder
EXE=build/bin/tile_example_layernorm2d_fwd
#!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
......
#!/bin/sh
# call from top of CK folder
EXE=./build/bin/tile_example_layernorm2d_fwd
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do
......
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
......@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_mem_pipeline -j
```
This will result in an executable `build/bin/tile_example_gemm_basic`
......
......@@ -17,10 +17,11 @@
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = true;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
......@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadA,
kPadB,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
......@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
......
......@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Gemm{MemBoundPipeline}"};
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
......@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
......@@ -202,14 +199,16 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work.
// else if(a_layout == "C" && b_layout == "C")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
......
......@@ -14,12 +14,34 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
// ToDo: This will be modified by the codegen code later.
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
......@@ -28,12 +50,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = true;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
......@@ -46,11 +68,14 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
......@@ -63,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
#endif
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile::GemmPipelineScheduler::Interwave,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile::GemmPipelineScheduler::Intrawave,
#endif
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
......@@ -174,8 +207,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
......
......@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......
......@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
......@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
// b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
else
{
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......
......@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
// b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
......
......@@ -69,7 +69,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
GammaDataType,
ComputeDataType,
......
......@@ -28,7 +28,6 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
#if 1
float r = -1;
// clang-format off
// rm rn tm tn vn pd rms 2p
......@@ -128,16 +127,12 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
}
return r;
#else
return rmsnorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a);
#endif
// clang-format on
}
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s)
{
float r = -1;
if(t.data_type.compare("fp16") == 0)
{
return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
......@@ -146,8 +141,6 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile:
{
return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
if(r < 0)
else
throw std::runtime_error("Without supported instances!");
return r;
}
......@@ -97,7 +97,7 @@ struct rmsnorm2d_fwd_traits_
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment