Commit 248fd588 authored by danyao12's avatar danyao12
Browse files

remove v3 spec

parent d61f4b83
...@@ -100,9 +100,6 @@ auto create_args(int argc, char* argv[]) ...@@ -100,9 +100,6 @@ auto create_args(int argc, char* argv[])
"v3_atomic_fp32", "v3_atomic_fp32",
"1", "1",
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when bwd_v3 is set to 1") "if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when bwd_v3 is set to 1")
.insert("v3_spec",
"0",
"if set to 1 will call the specialized v3 kernel when bwd_v3 is set to 1")
.insert("v3_bf16_cvt", .insert("v3_bf16_cvt",
"1", "1",
"float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ"); "float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ");
...@@ -211,7 +208,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -211,7 +208,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool deterministic = arg_parser.get_bool("deterministic"); bool deterministic = arg_parser.get_bool("deterministic");
bool bwd_v3 = arg_parser.get_bool("bwd_v3"); bool bwd_v3 = arg_parser.get_bool("bwd_v3");
bool v3_atomic_fp32 = arg_parser.get_bool("v3_atomic_fp32"); bool v3_atomic_fp32 = arg_parser.get_bool("v3_atomic_fp32");
bool v3_spec = arg_parser.get_bool("v3_spec");
int v3_bf16_cvt = arg_parser.get_int("v3_bf16_cvt"); int v3_bf16_cvt = arg_parser.get_int("v3_bf16_cvt");
ck_tile::stream_config stream_config{nullptr, ck_tile::stream_config stream_config{nullptr,
...@@ -454,7 +450,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -454,7 +450,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
deterministic, deterministic,
bwd_v3, bwd_v3,
v3_atomic_fp32, v3_atomic_fp32,
v3_spec,
v3_bf16_cvt}; v3_bf16_cvt};
auto fmha_args = [&]() { auto fmha_args = [&]() {
assert(nhead % nhead_k == 0); assert(nhead % nhead_k == 0);
......
...@@ -452,7 +452,6 @@ struct fmha_bwd_traits ...@@ -452,7 +452,6 @@ struct fmha_bwd_traits
bool is_deterministic; bool is_deterministic;
bool uses_bwd_v3; bool uses_bwd_v3;
bool is_v3_atomic_fp32; bool is_v3_atomic_fp32;
bool is_v3_spec;
int how_v3_bf16_cvt; int how_v3_bf16_cvt;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
...@@ -15,14 +15,10 @@ extern unsigned char bwd_bf16_causal_a16_rtz[]; ...@@ -15,14 +15,10 @@ extern unsigned char bwd_bf16_causal_a16_rtz[];
extern unsigned char bwd_bf16_causal_a32_rtna[]; extern unsigned char bwd_bf16_causal_a32_rtna[];
extern unsigned char bwd_bf16_causal_a32_rtne[]; extern unsigned char bwd_bf16_causal_a32_rtne[];
extern unsigned char bwd_bf16_causal_a32_rtz[]; extern unsigned char bwd_bf16_causal_a32_rtz[];
extern unsigned char bwd_bf16_spec_a32[];
extern unsigned char bwd_bf16_spec_causal_a32[];
extern unsigned char bwd_fp16_a16[]; extern unsigned char bwd_fp16_a16[];
extern unsigned char bwd_fp16_a32[]; extern unsigned char bwd_fp16_a32[];
extern unsigned char bwd_fp16_causal_a16[]; extern unsigned char bwd_fp16_causal_a16[];
extern unsigned char bwd_fp16_causal_a32[]; extern unsigned char bwd_fp16_causal_a32[];
extern unsigned char bwd_fp16_spec_a32[];
extern unsigned char bwd_fp16_spec_causal_a32[];
extern unsigned char bwd_hd64_bf16_a16_rtna[]; extern unsigned char bwd_hd64_bf16_a16_rtna[];
extern unsigned char bwd_hd64_bf16_a16_rtne[]; extern unsigned char bwd_hd64_bf16_a16_rtne[];
extern unsigned char bwd_hd64_bf16_a16_rtz[]; extern unsigned char bwd_hd64_bf16_a16_rtz[];
......
...@@ -12,16 +12,14 @@ for prec in "fp16" "bf16" ; do ...@@ -12,16 +12,14 @@ for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do for perm in 0 1 ; do
for hdim in 128 ; do for hdim in 128 ; do
for v3_atomic_fp32 in 0 1 ; do for v3_atomic_fp32 in 0 1 ; do
for v3_spec in 0 1 ; do
for mask in 0 1 ; do for mask in 0 1 ; do
$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_spec=$v3_spec -mode=0 -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_spec=$v3_spec -mode=0 -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -mode=0 -kname=$KNAME $COMMON_ARGS
done done
done done
done done
done done
done done
done
set +x set +x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment