"...composable_kernel_rocm.git" did not exist on "f4dfc060b79987580da9afc481dad746d5b3178d"
Commit 8ac3eb39 authored by danyao12's avatar danyao12
Browse files

asm code update

parent 67b160c5
...@@ -59,7 +59,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") ...@@ -59,7 +59,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check" # to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}") message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp) add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
......
...@@ -220,6 +220,8 @@ struct __attribute__((packed)) fmha_bwd_asm_args ...@@ -220,6 +220,8 @@ struct __attribute__((packed)) fmha_bwd_asm_args
p3 _p13; p3 _p13;
unsigned int BAs; unsigned int BAs;
p3 _p14; p3 _p14;
unsigned int Seqs;
p3 _p15;
}}; }};
struct fmha_bwd_ext_traits struct fmha_bwd_ext_traits
...@@ -294,7 +296,7 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -294,7 +296,7 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
}} }}
template <typename dot_do_o_trait_> template <typename dot_do_o_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name) float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << std::flush;
...@@ -311,9 +313,21 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned ...@@ -311,9 +313,21 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
args.scalar = a.scale; args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>; args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q; args.seq_len = a.seqlen_q;
args.Ts = 128 * a.hdim_q * 2;
args.Hs = a.seqlen_q * a.hdim_q * 2; int stride_tg = 128 * a.hdim_q * 2;
args.BAs = a.nhead_q * a.seqlen_q * a.hdim_q * 2; int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_tg = 128 * stride_seqlen;
stride_head = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
auto traits = fmha_bwd_ext_traits{{a.batch, auto traits = fmha_bwd_ext_traits{{a.batch,
a.nhead_q, a.nhead_q,
a.seqlen_q, a.seqlen_q,
...@@ -330,7 +344,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned ...@@ -330,7 +344,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
}} }}
template <typename dot_do_o_trait_, typename convert_dq_trait_> template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name) float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
...@@ -347,9 +361,21 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned ...@@ -347,9 +361,21 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
args.scalar = a.scale; args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>; args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q; args.seq_len = a.seqlen_q;
args.Ts = 128 * a.hdim_q * 2;
args.Hs = a.seqlen_q * a.hdim_q * 2; int stride_tg = 128 * a.hdim_q * 2;
args.BAs = a.nhead_q * a.seqlen_q * a.hdim_q * 2; int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_tg = 128 * stride_seqlen;
stride_head = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
auto traits = fmha_bwd_ext_traits{{a.batch, auto traits = fmha_bwd_ext_traits{{a.batch,
a.nhead_q, a.nhead_q,
a.seqlen_q, a.seqlen_q,
...@@ -371,45 +397,79 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -371,45 +397,79 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if (t.uses_ext_asm == true){{ if (t.uses_ext_asm == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false)) {{ (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
(a.stride_q == a.stride_dq /*i_perm == o_perm*/) && (a.stride_k == a.stride_dk /*i_perm == o_perm*/) &&
(a.stride_v == a.stride_dv /*i_perm == o_perm*/) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)) {{
if(t.data_type.compare("fp16") == 0){{ if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if(t.is_asm_atomic_fp32 == true){{ if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a32"; const std::string bwd_ext_name = "bwd_ext_fp16_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name); bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name, io_perm);
return r; return r;
}} }}
else{{ else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a16"; const std::string bwd_ext_name = "bwd_ext_fp16_a16";
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name); bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name, io_perm);
return r; return r;
}} }}
}} }}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32"; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name); const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
return r; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name, io_perm);
return r;
}}
else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_ext_name, io_perm);
return r;
}}
}} }}
}} }}
else if(t.data_type.compare("bf16") == 0){{ else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a32"; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name); const std::string bwd_ext_name = "bwd_ext_bf16_a32";
return r; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name, io_perm);
return r;
}}
else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
return r;
}}
}} }}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32"; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name); const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32";
return r; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name, io_perm);
return r;
}}
else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
return r;
}}
}} }}
}} }}
}} }}
...@@ -688,14 +748,14 @@ class FmhaBwdDQDKDVKernel: ...@@ -688,14 +748,14 @@ class FmhaBwdDQDKDVKernel:
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"], # "kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"], # "kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"], "kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"] # "kr_ktr_vr_iglp", "kr_ktr_vr"]
} }
else: else:
return None return None
...@@ -738,7 +798,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -738,7 +798,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue continue
if receipt == 3: if receipt == 3:
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi'] cond &= bias in ['no']
cond &= dpad == dvpad cond &= dpad == dvpad
cond &= deterministic == "f" cond &= deterministic == "f"
if not cond: if not cond:
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
#pragma once #pragma once
extern unsigned char bwd_arg[]; extern unsigned char bwd_bf16_a16[];
extern unsigned char bwd_bf16_a32[]; extern unsigned char bwd_bf16_a32[];
extern unsigned char bwd_bf16_causal_a16[];
extern unsigned char bwd_bf16_causal_a32[]; extern unsigned char bwd_bf16_causal_a32[];
extern unsigned char bwd_bf16_nocoex_a32[]; extern unsigned char bwd_bf16_nocoex_a32[];
extern unsigned char bwd_bf16_nocoex_causal_a32[]; extern unsigned char bwd_bf16_nocoex_causal_a32[];
extern unsigned char bwd_fp16_a16[]; extern unsigned char bwd_fp16_a16[];
extern unsigned char bwd_fp16_a32[]; extern unsigned char bwd_fp16_a32[];
extern unsigned char bwd_fp16_causal_a16[];
extern unsigned char bwd_fp16_causal_a32[]; extern unsigned char bwd_fp16_causal_a32[];
extern unsigned char bwd_fp16_nocoex_a32[]; extern unsigned char bwd_fp16_nocoex_a32[];
extern unsigned char bwd_fp16_nocoex_causal_a32[]; extern unsigned char bwd_fp16_nocoex_causal_a32[];
#!/bin/sh
# Smoke test for the external-assembly FMHA backward kernels, driven through
# the tile_example_fmha_bwd example binary.
#
# Usage: run this script from the CK root or from the build directory; the
# binary is located by searching below the current working directory.

# Take the first matching regular file; there may be several build trees.
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"

# Fail once, clearly, instead of letting every loop iteration try to execute
# "-prec=..." as a command when the binary has not been built.
if [ -z "$EXE" ]; then
    echo "error: tile_example_fmha_bwd not found under $(pwd); build it first or run from the CK root/build directory" >&2
    exit 1
fi

KNAME=1

# Exported so the example binary sees them; presumably they select zero
# warm-up iterations and a single timed run -- TODO confirm against the example.
export CK_WARMUP=0
export CK_REPEAT=1

# Arguments shared by every invocation. Intentionally expanded unquoted below
# so that multiple space-separated options would be split into separate words.
COMMON_ARGS='-v=1'

set -x
# Sweep precision x layout permutation x head dim x atomic mode x mask, and
# run two problem shapes (b=4,h=2,s=512 and b=1,h=3,s=768) per combination.
for prec in "fp16" "bf16" ; do
for perm in 1 ; do
for hdim in 128 ; do
for asm_atomic_fp32 in 0 1 ; do
for mask in 0 1 ; do

"$EXE" -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -mode=0 -kname=$KNAME $COMMON_ARGS
"$EXE" -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -mode=0 -kname=$KNAME $COMMON_ARGS

done
done
done
done
done
set +x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment