gaoqiong / composable_kernel_ROCM / Commits / 72c9f129

Commit 72c9f129, authored Sep 20, 2024 by Jun Liu
Merge branch 'amd-develop' into 'amd-master'
Parents: 241c261f, ded0d83d

Changes: 235 in total. Showing 20 changed files with 2356 additions and 633 deletions (+2356 −633):
example/ck_tile/01_fmha/fmha_bwd.cpp                      +49   -18
example/ck_tile/01_fmha/fmha_bwd.hpp                      +95   -11
example/ck_tile/01_fmha/fmha_fwd.cpp                      +616  -173
example/ck_tile/01_fmha/fmha_fwd.hpp                      +292  -33
example/ck_tile/01_fmha/generate.py                       +16   -11
example/ck_tile/01_fmha/rotary.hpp                        +84   -0
example/ck_tile/01_fmha/script/benchmark_bwd.sh           +2    -3
example/ck_tile/01_fmha/script/benchmark_fwd.sh           +2    -3
example/ck_tile/01_fmha/script/smoke_test_bwd.sh          +14   -13
example/ck_tile/01_fmha/script/smoke_test_fwd.sh          +94   -41
example/ck_tile/01_fmha/utils.hpp                         +97   -13
include/ck/ck.hpp                                         +3    -3
include/ck/filesystem.hpp                                 +135  -0
include/ck/host_utility/device_prop.hpp                   +6    -0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp                                       +43   -0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp                  +282  -99
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp                             +142  -162
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp                       +6    -6
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp  +5    -5
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp   +373  -39
example/ck_tile/01_fmha/fmha_bwd.cpp

@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
-        .insert("repeat", "20", "number of iterations to benchmark the kernel");
+        .insert("repeat", "20", "number of iterations to benchmark the kernel")
+        .insert("deterministic",
+                "0",
+                "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
+                "will not be used");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
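The new 'deterministic' flag trades memory for run-to-run reproducibility: with atomic updates, the order in which key-block workgroups add their partial dQ contributions into one buffer varies between runs, and floating-point addition is not associative. With one dq_acc slice per key-block split, the final sum can instead be taken in a fixed order. A minimal host-side sketch of the idea (illustrative only, not the kernel code):

#include <cstddef>
#include <vector>

// Deterministic split-buffer reduction: each "split" owns a private
// accumulator slice; a final pass sums the slices in a fixed order,
// so the result does not depend on thread scheduling.
std::vector<float> reduce_splits(const std::vector<std::vector<float>>& dq_acc)
{
    std::vector<float> dq(dq_acc.front().size(), 0.0f);
    for(const auto& slice : dq_acc)              // fixed iteration order
        for(std::size_t i = 0; i < dq.size(); ++i)
            dq[i] += slice[i];                   // no atomics needed
    return dq;
}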
@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
     if(hdim_v < 0)
         hdim_v = hdim_q;

-    if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
-    {
-        std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
-        return false;
-    }
-
     bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
     bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
@@ -180,6 +179,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
     bool kname         = arg_parser.get_bool("kname");
+    bool deterministic = arg_parser.get_bool("deterministic");

     ck_tile::stream_config stream_config{nullptr,
                                          true,
@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
     const ck_tile::index_t shape_seqlen_k =
         (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
+    const ck_tile::index_t kN0     = (hdim_q <= 128) ? 128 : 64;
+    const ck_tile::index_t nsplits =
+        deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;

     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
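For reference, integer_divide_ceil(a, b) is ordinary ceiling division, so each split covers one kN0-wide tile of keys. A tiny standalone check of the arithmetic (plain C++; the one-line definition mirrors how such helpers are commonly written and is an assumption, not the ck_tile source):

#include <cassert>

// Ceiling division: smallest n with n * b >= a.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int max_seqlen_k = 4096;
    const int hdim_q       = 128;
    const int kN0          = (hdim_q <= 128) ? 128 : 64;
    // deterministic mode would use 32 dq_acc slices here
    assert(integer_divide_ceil(max_seqlen_k, kN0) == 32);
}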
@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<ODataType> o_host(
         get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
     ck_tile::HostTensor<LSEDataType> lse_host(
-        std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
+        std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
     ck_tile::HostTensor<DDataType> d_host(
-        std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
+        std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
     ck_tile::HostTensor<RandValOutputDataType> randval_host(
         p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                    : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                   : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
+    ck_tile::HostTensor<AccDataType> dq_acc_host(
+        i_perm ? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
+               : std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});

     if(init_method == 0)
     {
@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());

     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());
@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
               << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
               << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
-              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask << std::flush;
+              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
+              << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
+
+    std::size_t workspace_size =
+        dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
+    if(deterministic == 1)
+    {
+        std::cout << "\nDeterministic mode ON: " << workspace_size
+                  << " MByte memory workspace allocated" << std::endl;
+    }

     auto fmha_traits = fmha_bwd_traits{hdim_q,
                                        hdim_v,
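For a sense of scale on that workspace message: in deterministic mode dq_acc holds one dQ-sized slice per split, i.e. nsplits × batch × nhead × seqlen_q × hdim_q accumulator elements. Assuming AccDataType is a 4-byte float (an assumption; the typedef lives elsewhere in the example), nsplits = 32, batch = 2, nhead = 8, seqlen_q = 4096 and hdim_q = 128 give 32 · 2 · 8 · 4096 · 128 · 4 B = 1 GiB of scratch, which is what the "Deterministic mode ON" message reports in MByte.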
@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                        mask.type,
                                        bias.type,
                                        use_dbias,
-                                       p_drop > 0.0f};
+                                       p_drop > 0.0f,
+                                       s_randval,
+                                       deterministic};

     auto fmha_args = [&]() {
         assert(nhead % nhead_k == 0);
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t nhead_stride_o       = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
         const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
         const ck_tile::index_t nhead_stride_do      = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
-        const ck_tile::index_t nhead_stride_lsed    = max_seqlen_q;
+        const ck_tile::index_t nhead_stride_lsed    = shape_seqlen_q;
         const ck_tile::index_t nhead_stride_dbias =
             (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);

         // setup batch_stride_* arguments
@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t batch_stride_o       = (nhead * shape_seqlen_q * hdim_v);
         const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
         const ck_tile::index_t batch_stride_do      = (nhead * shape_seqlen_q * hdim_v);
-        const ck_tile::index_t batch_stride_lsed    = (nhead * max_seqlen_q);
+        const ck_tile::index_t batch_stride_lsed    = (nhead * shape_seqlen_q);
         const ck_tile::index_t batch_stride_dk      = (nhead * shape_seqlen_k * hdim_q);
         const ck_tile::index_t batch_stride_dv      = (nhead * shape_seqlen_k * hdim_v);
         const ck_tile::index_t batch_stride_dbias   = (nhead * shape_seqlen_q * max_seqlen_k);
+        const ck_tile::index_t split_stride_dq_acc  = (shape_batch * nhead * shape_seqlen_q * hdim_q);

         return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                              k_buf.GetDeviceBuffer(),
@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              dk_buf.GetDeviceBuffer(),
                              dv_buf.GetDeviceBuffer(),
                              dbias_buf.GetDeviceBuffer(),
+                             dq_acc_buf.GetDeviceBuffer(),
                              seqstart_q.GetDeviceBuffer(),
                              seqstart_k.GetDeviceBuffer(),
                              nullptr,
@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              stride_o,
                              stride_randval,
                              stride_do,
+                             stride_q, // stride_dq_acc
+                             stride_q, // stride_dq
                              stride_dk,
                              stride_dv,
                              stride_dbias,
@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              nhead_stride_randval,
                              nhead_stride_do,
                              nhead_stride_lsed,
+                             nhead_stride_q, // nhead_stride_dq_acc
+                             nhead_stride_q, // nhead_stride_dq
+                             nhead_stride_k, // nhead_stride_dk
+                             nhead_stride_v, // nhead_stride_dv
                              nhead_stride_dbias,
                              batch_stride_q,
                              batch_stride_k,
@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              batch_stride_randval,
                              batch_stride_do,
                              batch_stride_lsed,
+                             batch_stride_q, // batch_stride_dq_acc
+                             batch_stride_q, // batch_stride_dq
                              batch_stride_dk,
                              batch_stride_dv,
                              batch_stride_dbias,
+                             split_stride_dq_acc,
                              mask.left,
                              mask.right,
                              static_cast<ck_tile::index_t>(mask.type),
                              p_drop,
                              p_undrop,
                              s_randval,
                              {drop_seed, drop_offset}};
     }();
@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
         else       o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });

-        lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
+        lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
         // clang-format on

         q_host_refs.push_back(q_host_ref);
@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     lse_buf.ToDevice(lse_host.data());
     dq_buf.SetZero();
     dbias_buf.SetZero();
+    dq_acc_buf.SetZero();

     ck_tile::stream_config stream_config_v{
         nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
example/ck_tile/01_fmha/fmha_bwd.hpp

@@ -77,6 +77,7 @@ struct fmha_bwd_args
     void* dk_ptr;
     void* dv_ptr;
     void* dbias_ptr;
+    void* dq_acc_ptr;
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;
@@ -97,6 +98,8 @@ struct fmha_bwd_args
    ck_tile::index_t stride_o;
    ck_tile::index_t stride_randval;
    ck_tile::index_t stride_do;
+   ck_tile::index_t stride_dq_acc;
+   ck_tile::index_t stride_dq;
    ck_tile::index_t stride_dk;
    ck_tile::index_t stride_dv;
    ck_tile::index_t stride_dbias;
@@ -108,6 +111,10 @@ struct fmha_bwd_args
    ck_tile::index_t nhead_stride_randval;
    ck_tile::index_t nhead_stride_do;
    ck_tile::index_t nhead_stride_lsed;
+   ck_tile::index_t nhead_stride_dq_acc;
+   ck_tile::index_t nhead_stride_dq;
+   ck_tile::index_t nhead_stride_dk;
+   ck_tile::index_t nhead_stride_dv;
    ck_tile::index_t nhead_stride_dbias;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
@@ -117,15 +124,17 @@ struct fmha_bwd_args
    ck_tile::index_t batch_stride_randval;
    ck_tile::index_t batch_stride_do;
    ck_tile::index_t batch_stride_lsed;
+   ck_tile::index_t batch_stride_dq_acc;
+   ck_tile::index_t batch_stride_dq;
    ck_tile::index_t batch_stride_dk;
    ck_tile::index_t batch_stride_dv;
    ck_tile::index_t batch_stride_dbias;
+   ck_tile::index_t split_stride_dq_acc;
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
    float p_drop;
    float p_undrop;
    bool s_randval;
    std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
            args.do_ptr,
            args.d_ptr,
            args.rand_val_ptr,
-           args.dq_ptr,
            args.dk_ptr,
            args.dv_ptr,
            args.dbias_ptr,
+           args.dq_acc_ptr,
            args.seqstart_q_ptr,
            args.seqstart_k_ptr,
            args.seqlen_k_ptr,
@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
            args.stride_bias,
            args.stride_randval,
            args.stride_do,
+           args.stride_dq_acc,
            args.stride_dk,
            args.stride_dv,
            args.stride_dbias,
@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
            args.nhead_stride_randval,
            args.nhead_stride_do,
            args.nhead_stride_lsed,
+           args.nhead_stride_dq_acc,
            args.nhead_stride_dk,
            args.nhead_stride_dv,
            args.nhead_stride_dbias,
-           args.batch_stride_lsed,
+           args.split_stride_dq_acc,
            args.window_size_left,
            args.window_size_right,
            args.mask_type,
            args.p_drop,
+           args.s_randval,
            args.drop_seed_offset);
    }
    else
@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
            args.do_ptr,
            args.d_ptr,
            args.rand_val_ptr,
-           args.dq_ptr,
            args.dk_ptr,
            args.dv_ptr,
            args.dbias_ptr,
+           args.dq_acc_ptr,
            args.seqlen_q,
            args.seqlen_k,
            args.hdim_q,
@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
            args.stride_bias,
            args.stride_randval,
            args.stride_do,
+           args.stride_dq_acc,
            args.stride_dk,
            args.stride_dv,
            args.stride_dbias,
@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
            args.nhead_stride_randval,
            args.nhead_stride_do,
            args.nhead_stride_lsed,
+           args.nhead_stride_dq_acc,
+           args.nhead_stride_dk,
+           args.nhead_stride_dv,
            args.nhead_stride_dbias,
            args.batch_stride_q,
            args.batch_stride_k,
@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
            args.batch_stride_randval,
            args.batch_stride_do,
            args.batch_stride_lsed,
+           args.batch_stride_dq_acc,
            args.batch_stride_dk,
            args.batch_stride_dv,
            args.batch_stride_dbias,
+           args.split_stride_dq_acc,
            args.window_size_left,
            args.window_size_right,
            args.mask_type,
            args.p_drop,
+           args.s_randval,
            args.drop_seed_offset);
    }
}();
@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
            args.stride_o,
            args.nhead_stride_do,
            args.nhead_stride_o,
-           args.nhead_stride_lsed,
-           args.batch_stride_lsed);
+           args.nhead_stride_lsed);
    }
    else
    {
        // create batch mode kernel arguments
@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
    return ck_tile::make_tuple(kargs, grids);
}

+template <typename FmhaBwdConvertQGradKernel>
+auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
+{
+    auto kargs = [&] {
+        // create group mode kernel arguments
+        if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
+        {
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqstart_q_ptr,
+                                                        args.seqstart_k_ptr,
+                                                        args.hdim_q,
+                                                        args.stride_dq,
+                                                        args.stride_dq_acc,
+                                                        args.nhead_stride_dq,
+                                                        args.nhead_stride_dq_acc,
+                                                        args.split_stride_dq_acc);
+        }
+        else
+        {
+            // create batch mode kernel arguments
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqlen_q,
+                                                        args.seqlen_k,
+                                                        args.hdim_q,
+                                                        args.stride_dq,
+                                                        args.stride_dq_acc,
+                                                        args.nhead_stride_dq,
+                                                        args.nhead_stride_dq_acc,
+                                                        args.batch_stride_dq,
+                                                        args.batch_stride_dq_acc,
+                                                        args.split_stride_dq_acc);
+        }
+    }();
+
+    dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
+    return ck_tile::make_tuple(kargs, grids);
+}
+
 // this is used to pattern-match internl kernel implementation, not to instantiate kernel
 template <ck_tile::index_t HDim_,
           typename DataType_,
           bool kIsGroupMode_,
           ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
           typename FmhaMask_,
           typename FmhaDropout_,
           ck_tile::BlockAttentionBiasEnum BiasEnum_,
           bool kHasBiasGrad_,
           bool kHasDropout_,
           bool kPadS_,
           bool kPadSK_,
           bool kPadD_,
-          bool kPadDv_>
+          bool kPadDv_,
+          bool kIsDeterministic_>
 struct fmha_bwd_dq_dk_dv_traits_
 {
     static constexpr ck_tile::index_t HDim = HDim_;
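This convert-dq kernel is what closes the deterministic path: the dq_dk_dv kernel writes one partial dQ per split at split_stride_dq_acc intervals into dq_acc, and the convert pass folds them back into dq_ptr in a fixed order. Conceptually (a simplified sketch, not the kernel source):

#include <cstddef>

// Conceptual view of the dQ-convert pass: fold nsplits partial float
// accumulators back into the final dq buffer, walking the splits in a
// fixed order so the sum is reproducible run to run.
void convert_dq(float* dq, const float* dq_acc,
                std::size_t dq_elems, int nsplits, std::size_t split_stride)
{
    for(std::size_t i = 0; i < dq_elems; ++i)
    {
        float acc = 0.0f;
        for(int s = 0; s < nsplits; ++s)
            acc += dq_acc[s * split_stride + i];
        dq[i] = acc; // also the natural place for a fp32 -> fp16/bf16 cast
    }
}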
@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
     static constexpr bool kIsGroupMode        = kIsGroupMode_;
     static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
     using FmhaMask    = ck_tile::remove_cvref_t<FmhaMask_>;
     using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
     static constexpr auto BiasEnum     = BiasEnum_;
     static constexpr bool kHasBiasGrad = kHasBiasGrad_;
     static constexpr bool kHasDropout  = kHasDropout_;
     static constexpr bool kPadS  = kPadS_;
     static constexpr bool kPadSK = kPadSK_;
     static constexpr bool kPadD  = kPadD_;
     static constexpr bool kPadDv = kPadDv_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;
 };

 template <typename Traits_>
@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
 template <typename Traits_>
 std::string fmha_bwd_dot_do_o_get_name_();

+template <ck_tile::index_t HDim_,
+          typename DataType_,
+          bool kIsGroupMode_,
+          bool kPadS_,
+          bool kPadD_,
+          bool kIsDeterministic_>
+struct fmha_bwd_convert_dq_traits_
+{
+    static constexpr ck_tile::index_t HDim = HDim_;
+    using DataType = ck_tile::remove_cvref_t<DataType_>;
+    static constexpr bool kIsGroupMode     = kIsGroupMode_;
+    static constexpr bool kPadS            = kPadS_;
+    static constexpr bool kPadD            = kPadD_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;
+};
+
+template <typename Traits_>
+float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+std::string fmha_bwd_convert_dq_get_name_();
+
 // This is the public API, will be generated by script
 struct fmha_bwd_traits
 {
@@ -354,6 +436,8 @@ struct fmha_bwd_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_dbias;
     bool has_dropout;
+    bool is_store_randval;
+    bool is_deterministic;
     // TODO: padding check is inside this api
 };
 float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
example/ck_tile/01_fmha/fmha_fwd.cpp

@@ -4,6 +4,7 @@
 #include "fmha_fwd.hpp"
 #include "ck_tile/host.hpp"
 #include "mask.hpp"
+#include "rotary.hpp"
 #include "utils.hpp"

 #include <array>
@@ -16,6 +17,10 @@
 #include <utility>
 #include <vector>

+#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()"
+#endif
+
 template <typename T>
 std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
 {
@@ -50,7 +55,11 @@ auto create_args(int argc, char* argv[])
                 "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                 "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
                 "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)")
-        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
+        .insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s")
+        .insert("s_knew",
+                "0",
+                "seqlen_k for new key/value, 0 means not to use this at all; "
+                "-1 to choose s_knew in [1, s] randomly.")
         .insert("s_kpad",
                 "-1",
                 "seqlen_k stride between 2 tokens, currently used in group-mode only\n"
@@ -114,9 +123,14 @@ auto create_args(int argc, char* argv[])
         .insert("drop_seed", "1", "seed for random number generator")
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
+        .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
+        .insert("num_splits",
+                "1",
+                "# of splits for key/value. 0 to determine actual number by heuristic")
+        .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
+        .insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel");
@@ -244,20 +258,6 @@ int override_num_splits_if_necessary(
     return num_splits;
 }

-float fmha_fwd_dispatch(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config)
-{
-    if(1 < args.num_splits)
-    {
-        return fmha_fwd_splitkv(traits, args, config);
-    }
-    else
-    {
-        return fmha_fwd(traits, args, config);
-    }
-}
-
 template <typename DataType>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
@@ -276,11 +276,114 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }

+    std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
+    if(*seed == 0)
+    {
+        seed.reset();
+    }
+
+    ck_tile::index_t hdim_q = arg_parser.get_int("d");
+    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
+    if(hdim_v < 0)
+        hdim_v = hdim_q;
+
+    ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    if(seqlen_knew != 0)
+    {
+        std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+        seqlen_knew = 0;
+    }
+#endif
+    if(seqlen_knew < 0)
+    {
+        seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
+    }
+
+    ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
+    if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
+                   std::is_same_v<DataType, ck_tile::bf16_t>))
+    {
+        if(0 < rotary_dim)
+        {
+            std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
+            return false;
+        }
+    }
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    else if(0 < rotary_dim)
+    {
+        std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
+                  << std::endl;
+        rotary_dim = 0;
+    }
+#endif
+    if(!(rotary_dim <= hdim_q))
+    {
+        std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
+        return false;
+    }
+    else if(!(rotary_dim % 16 == 0))
+    {
+        std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
+        return false;
+    }
+
+    ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < page_block_size)
+    {
+        std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
+                  << std::endl;
+        page_block_size = 0;
+    }
+#endif
+    if(!(page_block_size % 128 == 0))
+    {
+        std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
+                  << std::endl;
+        return false;
+    }
+
+    bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(use_cache_batch_idx)
+    {
+        std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+#endif
+    if(0 < page_block_size && use_cache_batch_idx)
+    {
+        std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                     "'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+
+    // the input tensor layout for kvcache is same as batch mode
+    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+    const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
+    if(use_kvcache && mode != mode_enum::batch)
+    {
+        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
+        mode = mode_enum::batch;
+    }
+
     auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
                                                               batch,
                                                               arg_parser.get_str("s"),
                                                               arg_parser.get_str("s_k"),
-                                                              arg_parser.get_str("s_kpad"));
+                                                              arg_parser.get_str("s_kpad"),
+                                                              /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
+                                                              use_kvcache);
+
+    // compute kvcache seqlen_k (before appending knew/vnew)
+    auto cache_seqlen_ks = seqlen_ks;
+    std::transform(cache_seqlen_ks.begin(),
+                   cache_seqlen_ks.end(),
+                   cache_seqlen_ks.begin(),
+                   [&](auto seqlen_k) { return seqlen_k - seqlen_knew; });

 #if 0
 // clang-format off
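The two rotary layouts gated by 'rotary_interleaved' above follow the standard RoPE convention: interleaved RoPE rotates adjacent element pairs (x0,x1), (x2,x3), ..., while the half-rotated layout pairs element i with element i + rotary_dim/2. A self-contained sketch of the math (illustrative, not ck_tile code):

#include <cmath>
#include <vector>

// Standard RoPE on the first rotary_dim elements of one head vector x,
// at sequence position pos. The cos/sin would normally come from a
// precomputed per-position table (cf. generate_rotary_cos_sin below).
void rope(std::vector<float>& x, int rotary_dim, int pos, bool interleaved)
{
    for(int i = 0; i < rotary_dim / 2; ++i)
    {
        const float theta = pos * std::pow(10000.0f, -2.0f * i / rotary_dim);
        const float c = std::cos(theta), s = std::sin(theta);
        const int a = interleaved ? 2 * i : i;                       // first of the pair
        const int b = interleaved ? 2 * i + 1 : i + rotary_dim / 2;  // second of the pair
        const float xa = x[a], xb = x[b];
        x[a] = xa * c - xb * s;
        x[b] = xa * s + xb * c;
    }
}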
@@ -290,11 +393,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
 // clang-format on
 #endif

-    ck_tile::index_t hdim_q = arg_parser.get_int("d");
-    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
-    if(hdim_v < 0)
-        hdim_v = hdim_q;
-
     bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
     bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
@@ -357,13 +455,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }

     std::string init_method = arg_parser.get_str("init");
-    std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
-    if(*seed == 0)
-    {
-        seed.reset();
-    }
-
-    int num_splits = arg_parser.get_int("num_splits");
+    const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
+
+    ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
+#if !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(num_splits != 1)
+    {
+        std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
+        num_splits = 1;
+    }
+#endif

     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
@@ -425,6 +527,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
     }

+    const ck_tile::index_t max_num_page_blocks =
+        (0 < page_block_size
+             ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
+             : 0);
+
     // legalize num_splits according to other options
     if(num_splits < 1)
     {
@@ -436,6 +543,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         std::cerr << "num_splits greater than 128 is not supported" << std::endl;
         return false;
     }
+#if CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < p_drop && (1 < num_splits || use_kvcache))
+    {
+        std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
+                  << std::endl;
+        p_drop = 0.0f;
+    }
+#endif

     auto get_lengths = [&](bool permute,
                            ck_tile::index_t b /*batch*/,
ck_tile
::
HostTensor
<
QDataType
>
q_host
(
get_lengths
(
i_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_q
));
ck_tile
::
HostTensor
<
KDataType
>
k_host
(
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_q
));
0
<
page_block_size
?
get_lengths
(
i_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_q
)
:
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_q
));
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile
::
HostTensor
<
KDataType
>
knew_host
(
0
<
seqlen_knew
?
get_lengths
(
i_perm
,
batch
,
nhead_k
,
seqlen_knew
,
hdim_q
)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
VDataType
>
v_host
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
)
:
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
hdim_v
,
shape_seqlen_k
));
0
<
page_block_size
?
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_v
)
:
get_lengths
(
i_perm
,
max_num_page_blocks
,
nhead_k
,
hdim_v
,
page_block_size
))
:
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
)
:
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
hdim_v
,
shape_seqlen_k
)));
ck_tile
::
HostTensor
<
VDataType
>
vnew_host
(
0
<
seqlen_knew
?
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
batch
,
nhead_k
,
seqlen_knew
,
hdim_v
)
:
get_lengths
(
i_perm
,
batch
,
nhead_k
,
hdim_v
,
seqlen_knew
))
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
BiasDataType
>
bias_host
(
bias
.
type
==
bias_enum
::
elementwise_bias
?
get_lengths
(
i_perm
,
1
,
1
,
shape_seqlen_q
,
shape_seqlen_k
)
...
...
@@ -478,17 +608,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
             : std::array<ck_tile::index_t, 2>{batch, nhead})
             : std::array<ck_tile::index_t, 2>{1, 1});
+    auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
+        std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);

     ck_tile::HostTensor<LSEDataType> lse_acc_host(
-        1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
+        1 < num_splits || use_kvcache
+            ? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
             : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
     ck_tile::HostTensor<OaccDataType> o_acc_host(
-        1 < num_splits
+        1 < num_splits || use_kvcache
             ? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
             : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});

-    // self define lse data layout as [batch, nhead, max_seqlen_q]
+    // batch mode of lse data layout is [batch, nhead, seqlen_q]
+    // group mode of lse data layout is [nhead, total_seqlen_q]
     ck_tile::HostTensor<LSEDataType> lse_host(
-        lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
+        lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
            : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);

     ck_tile::HostTensor<ODataType> o_host(
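The lse_acc/o_acc tensors sized above hold one partial softmax result per split; merging them afterwards is the usual log-sum-exp combination, where each split's output is weighted by exp(lse_s - lse_total). A standalone sketch of that math (not the ck_tile combine kernel):

#include <cmath>
#include <cstddef>
#include <vector>

// Combine per-split partial attention outputs. Split s produced o[s]
// (length hdim_v) plus its log-sum-exp lse[s]; the exact result is a
// weighted sum with weights exp(lse[s] - lse_total).
std::vector<float> combine_splits(const std::vector<std::vector<float>>& o,
                                  const std::vector<float>& lse)
{
    float m = lse[0];
    for(float l : lse) m = std::max(m, l);       // max for numerical stability
    float sum = 0.0f;
    for(float l : lse) sum += std::exp(l - m);
    const float lse_total = m + std::log(sum);

    std::vector<float> out(o[0].size(), 0.0f);
    for(std::size_t s = 0; s < o.size(); ++s)
    {
        const float w = std::exp(lse[s] - lse_total);
        for(std::size_t d = 0; d < out.size(); ++d)
            out[d] += w * o[s][d];
    }
    return out;
}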
@@ -498,39 +633,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
         p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                    : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
+    ck_tile::HostTensor<int32_t> block_table_host(
+        0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
+                            : std::array<ck_tile::index_t, 2>{1, 1});
+    ck_tile::HostTensor<int32_t> cache_batch_idx_host(
+        use_cache_batch_idx ? std::array<ck_tile::index_t, 1>{batch}
+                            : std::array<ck_tile::index_t, 1>{1});

     if(init_method == "ui" || init_method == "0")
     {
         ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
         ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
+        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
         ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
+        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
         ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
     }
     else if(init_method == "ni")
     {
         ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
         ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
+        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
         ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
+        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
         ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
     }
     else if(init_method == "uf" || init_method == "1")
     {
         ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
         ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
+        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
         ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
+        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
         ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
     }
     else if(init_method == "nf")
     {
         ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
         ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
+        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
         ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
+        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
         ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
     }
     else if(init_method == "tf" || init_method == "2")
     {
         ck_tile::FillTrigValue<QDataType>{}(q_host);
         ck_tile::FillTrigValue<KDataType>{}(k_host);
+        ck_tile::FillTrigValue<KDataType>{}(knew_host);
         ck_tile::FillTrigValue<VDataType>{}(v_host);
+        ck_tile::FillTrigValue<VDataType>{}(vnew_host);
         ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
     }
     else if(init_method == "ufq" || init_method == "uf:q" ||
@@ -538,7 +691,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
         ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
+        ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
         ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
+        ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);

         // bias_fp8 = qscale_bias * bias_fp32
         float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
@@ -548,7 +703,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(bias.type == bias_enum::alibi)
     {
         auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
-        assert(slopes.size() == nhead);
+        assert(slopes.size() == static_cast<std::size_t>(nhead));
         if(bias.rank_info == 0)
         {
             // alibi in 1*h
@@ -563,10 +718,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
     }

+    iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
+    iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
+
     ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
@@ -574,27 +733,41 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
-    ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
+    ck_tile::DeviceMem seqlen_k_buf(
+        use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem cache_seqlen_k_buf(
+        need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());

     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());
+    knew_buf.ToDevice(knew_host.data());
     v_buf.ToDevice(v_host.data());
+    vnew_buf.ToDevice(vnew_host.data());
     bias_buf.ToDevice(bias_host.data());
     seqstart_q.ToDevice(seqstart_q_host.data());
     seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
                                             : seqstart_k_with_padding_host.data());
-    seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
+    seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
+    cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
+    rotary_cos_buf.ToDevice(rotary_cos_host.data());
+    rotary_sin_buf.ToDevice(rotary_sin_host.data());
     alibi_slope_buf.ToDevice(alibi_slope_host.data());
+    block_table_buf.ToDevice(block_table_host.data());
+    cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());

     // clang-format off
     auto layout_str = [&](bool permute) {
         if(permute) return std::string("bhsd");
         else        return std::string("bshd");
     };
     auto io_layout = [&](bool iperm_, bool operm_) {
         if(iperm_ == operm_) return layout_str(iperm_);
         else                 return layout_str(iperm_) + std::string("-") + layout_str(operm_);
     };
     // clang-format on
@@ -607,39 +780,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
               << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
               << ", mask:" << mask << ", v:" << vlayout;
+#if CK_TILE_FMHA_FWD_APPENDKV_API
+    if(0 < rotary_dim)
+    {
+        std::cout << ", rotary_dim:" << rotary_dim << "("
+                  << (is_rotary_interleaved ? "inter" : "half") << ")";
+    }
+#endif
+#if CK_TILE_FMHA_FWD_SPLITKV_API
+    if(1 < num_splits)
+    {
+        std::cout << ", num_splits:" << num_splits;
+    }
+    if(0 < page_block_size)
+    {
+        std::cout << ", page_block_size:" << page_block_size;
+    }
+    if(use_cache_batch_idx)
+    {
+        std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
+    }
+#endif
+    std::cout << std::flush;

-    auto fmha_traits = fmha_fwd_traits{hdim_q,
-                                       hdim_v,
-                                       data_type,
-                                       mode == mode_enum::group,
-                                       is_v_rowmajor,
-                                       mask.type,
-                                       bias.type,
-                                       lse,
-                                       p_drop > 0.0f,
-                                       squant};
-
-    auto p_compute_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
-            return ck_tile::scales{scale_p};
-        else
-            return ck_tile::identity{};
-    }();
-
-    auto oacc_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
-            return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
-                                     ck_tile::scales{scale_o});
-        else
-            return ck_tile::identity{};
-    }();
+    const auto init_traits = [&](auto& traits) {
+        traits.hdim_q        = hdim_q;
+        traits.hdim_v        = hdim_v;
+        traits.data_type     = data_type;
+        traits.is_v_rowmajor = is_v_rowmajor;
+
+        if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
+        {
+            traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
+                                                                        : rope_enum::half_rotated)
+                                               : rope_enum::none);
+        }
+        else // fmha_fwd_traits or fmha_splitkv_traits
+        {
+            traits.is_group_mode       = (mode == mode_enum::group);
+            traits.mask_type           = mask.type;
+            traits.bias_type           = bias.type;
+            traits.has_lse             = lse;
+            traits.do_fp8_static_quant = squant;
+
+            if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
+            {
+                traits.has_dropout = (p_drop > 0.0f);
+            }
+        }
+    };

-    auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
+    const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
         assert(nhead % nhead_k == 0);
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
         /// seqlen_k] in this example, hence both the 'batch_stride_bias' &
@@ -647,11 +838,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
         // setup stride_* arguments
         const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
         const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
+        const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
         const ck_tile::index_t stride_v = [&]() {
             if(is_v_rowmajor)
                 return i_perm ? hdim_v : nhead_k * hdim_v;
             else
-                return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
+                return 0 < page_block_size
+                           ? (i_perm ? page_block_size : nhead_k * page_block_size)
+                           : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
         }();
+        const ck_tile::index_t stride_vnew = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? hdim_v : nhead_k * hdim_v;
+            else
+                return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
+        }();
         const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
         const ck_tile::index_t stride_randval = (max_seqlen_k);
@@ -659,103 +858,220 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
         // setup nhead_stride_* arguments
         const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
-        const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
+        const ck_tile::index_t nhead_stride_k =
+            (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
+                                 : (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
+        const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
         const ck_tile::index_t nhead_stride_v = [&]() {
             if(is_v_rowmajor)
-                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
+                return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
+                                           : (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
             else
-                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
+                return 0 < page_block_size
+                           ? (i_perm ? hdim_v * page_block_size : page_block_size)
+                           : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
         }();
+        const ck_tile::index_t nhead_stride_vnew = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? seqlen_knew * hdim_v : hdim_v;
+            else
+                return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
+        }();
         const ck_tile::index_t nhead_stride_bias =
             (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
         const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
-        const ck_tile::index_t nhead_stride_lse     = max_seqlen_q;
-        const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
+        const ck_tile::index_t nhead_stride_lse     = shape_seqlen_q;
+        const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
         const ck_tile::index_t nhead_stride_o_acc   = (max_seqlen_q * hdim_v);
         const ck_tile::index_t nhead_stride_o       = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
         // setup batch_stride_* arguments
         const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
-        const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
-        const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
+        const ck_tile::index_t batch_stride_k =
+            (0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
+                                 : (nhead_k * shape_seqlen_k * hdim_q));
+        const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
+        const ck_tile::index_t batch_stride_v =
+            (0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
+                                 : (nhead_k * hdim_v * shape_seqlen_k));
+        const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
         const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
         const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
-        const ck_tile::index_t batch_stride_lse     = (nhead * max_seqlen_q);
-        const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
+        const ck_tile::index_t batch_stride_lse     = (nhead * shape_seqlen_q);
+        const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
         const ck_tile::index_t batch_stride_o_acc   = (nhead * max_seqlen_q * hdim_v);
         const ck_tile::index_t batch_stride_o       = (nhead * shape_seqlen_q * hdim_v);
+        const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
         // setup split_stride_* arguments (only used in split-kv kernel)
-        const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
+        const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
         const ck_tile::index_t split_stride_o_acc   = (batch * nhead * max_seqlen_q * hdim_v);

-        return fmha_fwd_args{q_buf.GetDeviceBuffer(),
-                             k_buf.GetDeviceBuffer(),
-                             v_buf.GetDeviceBuffer(),
-                             bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
-                                                           : bias_buf.GetDeviceBuffer(),
-                             randval_buf.GetDeviceBuffer(),
-                             lse_acc_buf.GetDeviceBuffer(),
-                             o_acc_buf.GetDeviceBuffer(),
-                             lse_buf.GetDeviceBuffer(),
-                             o_buf.GetDeviceBuffer(),
-                             seqstart_q.GetDeviceBuffer(),
-                             seqstart_k.GetDeviceBuffer(),
-                             k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
-                             shape_seqlen_q,
-                             shape_seqlen_k,
-                             batch,
-                             max_seqlen_q,
-                             hdim_q,
-                             hdim_v,
-                             nhead,
-                             nhead_k,
-                             num_splits,
-                             scale_s,
-                             scale_p,
-                             scale_o,
-                             stride_q,
-                             stride_k,
-                             stride_v,
-                             bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
-                                                           : stride_bias,
-                             stride_randval,
-                             stride_o_acc,
-                             stride_o,
-                             nhead_stride_q,
-                             nhead_stride_k,
-                             nhead_stride_v,
-                             nhead_stride_bias,
-                             nhead_stride_randval,
-                             nhead_stride_lse,
-                             nhead_stride_lse_acc,
-                             nhead_stride_o_acc,
-                             nhead_stride_o,
-                             batch_stride_q,
-                             batch_stride_k,
-                             batch_stride_v,
-                             batch_stride_bias,
-                             batch_stride_randval,
-                             batch_stride_lse,
-                             batch_stride_lse_acc,
-                             batch_stride_o_acc,
-                             batch_stride_o,
-                             split_stride_lse_acc,
-                             split_stride_o_acc,
-                             mask.left,
-                             mask.right,
-                             static_cast<ck_tile::index_t>(mask.type),
-                             p_drop,
-                             s_randval,
-                             {drop_seed, drop_offset}};
-    }();
+        args.q_ptr = q_buf.GetDeviceBuffer();
+        args.k_ptr = k_buf.GetDeviceBuffer();
+        args.v_ptr = v_buf.GetDeviceBuffer();
+        args.batch    = batch;
+        args.seqlen_q = shape_seqlen_q; // unused in group mode
+        args.hdim_q   = hdim_q;
+        args.hdim_v   = hdim_v;
+        args.nhead_q  = nhead;
+        args.nhead_k  = nhead_k;
+        args.stride_q = stride_q;
+        args.stride_k = stride_k;
+        args.stride_v = stride_v;
+        args.nhead_stride_q = nhead_stride_q;
+        args.nhead_stride_k = nhead_stride_k;
+        args.nhead_stride_v = nhead_stride_v;
+        args.batch_stride_q = batch_stride_q;
+        args.batch_stride_k = batch_stride_k;
+        args.batch_stride_v = batch_stride_v;
+
+        if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
+        {
+            args.knew_ptr     = knew_buf.GetDeviceBuffer();
+            args.vnew_ptr     = vnew_buf.GetDeviceBuffer();
+            args.seqlen_knew  = seqlen_knew;
+            args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
+            args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr);
+            args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr);
+            args.rotary_dim = rotary_dim;
+            args.has_mask   = (mask.type != mask_enum::no_mask);
+            args.block_table_ptr =
+                (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
+            args.batch_stride_block_table = batch_stride_block_table;
+            args.page_block_size          = page_block_size;
+            args.cache_batch_idx =
+                (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
+            args.stride_knew       = stride_knew;
+            args.stride_vnew       = stride_vnew;
+            args.nhead_stride_knew = nhead_stride_knew;
+            args.nhead_stride_vnew = nhead_stride_vnew;
+            args.batch_stride_knew = batch_stride_knew;
+            args.batch_stride_vnew = batch_stride_vnew;
+        }
+        else // fmha_fwd_args or fmha_fwd_splitkv_args
+        {
+            args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
+                                                          : bias_buf.GetDeviceBuffer();
+            args.lse_ptr  = lse_buf.GetDeviceBuffer();
+            args.o_ptr    = o_buf.GetDeviceBuffer();
+            args.seqstart_q_ptr =
+                (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
+            args.seqstart_k_ptr =
+                (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
+            args.seqlen_k_ptr =
+                (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
+            args.seqlen_k     = shape_seqlen_k; // unused in group mode (or kvcache enabled)
+            args.max_seqlen_q = max_seqlen_q;
+            args.scale_s = scale_s;
+            args.scale_p = scale_p;
+            args.scale_o = scale_o;
+            args.stride_bias = (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
+                                                              : stride_bias);
+            args.stride_o          = stride_o;
+            args.nhead_stride_bias = nhead_stride_bias;
+            args.nhead_stride_lse  = nhead_stride_lse;
+            args.nhead_stride_o    = nhead_stride_o;
+            args.batch_stride_bias = batch_stride_bias;
+            args.batch_stride_lse  = batch_stride_lse;
+            args.batch_stride_o    = batch_stride_o;
+            args.window_size_left  = mask.left;
+            args.window_size_right = mask.right;
+            args.mask_type         = static_cast<ck_tile::index_t>(mask.type);
+
+            if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
+            {
+                args.rand_val_ptr         = randval_buf.GetDeviceBuffer();
+                args.stride_randval       = stride_randval;
+                args.nhead_stride_randval = nhead_stride_randval;
+                args.batch_stride_randval = batch_stride_randval;
+                args.p_drop               = p_drop;
+                args.s_randval            = s_randval;
+                args.drop_seed_offset     = std::tie(drop_seed, drop_offset);
+            }
+            else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
+            {
+                args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
+                args.o_acc_ptr   = o_acc_buf.GetDeviceBuffer();
+                args.block_table_ptr =
+                    (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
+                args.batch_stride_block_table = batch_stride_block_table;
+                args.page_block_size          = page_block_size;
+                args.cache_batch_idx =
+                    (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
+                args.num_splits           = num_splits;
+                args.stride_o_acc         = stride_o_acc;
+                args.nhead_stride_lse_acc = nhead_stride_lse_acc;
+                args.nhead_stride_o_acc   = nhead_stride_o_acc;
+                args.batch_stride_lse_acc = batch_stride_lse_acc;
+                args.batch_stride_o_acc   = batch_stride_o_acc;
+                args.split_stride_lse_acc = split_stride_lse_acc;
+                args.split_stride_o_acc   = split_stride_o_acc;
+            }
+        }
+    };

+    const float appendkv_ave_time = [&] {
+#if CK_TILE_FMHA_FWD_APPENDKV_API
+        if(need_append_kvcache)
+        {
+            fmha_fwd_appendkv_traits fwd_appendkv_traits;
+            init_traits(fwd_appendkv_traits);
+            fmha_fwd_appendkv_args fwd_appendkv_args;
+            init_args(fwd_appendkv_args);
+            return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
+        }
+#endif
+        return 0.0f;
+    }();

-    float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
+    const float fwd_ave_time = [&] {
+#if CK_TILE_FMHA_FWD_SPLITKV_API
+        if(1 < num_splits || use_kvcache)
+        {
+            fmha_fwd_splitkv_traits fmha_splitkv_traits;
+            init_traits(fmha_splitkv_traits);
+            fmha_fwd_splitkv_args fmha_splitkv_args;
+            init_args(fmha_splitkv_args);
+            return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
+        }
+#endif
+        fmha_fwd_traits fmha_traits;
+        init_traits(fmha_traits);
+        fmha_fwd_args fmha_args;
+        init_args(fmha_args);
+        return fmha_fwd(fmha_traits, fmha_args, stream_config);
+    }();

-    if(ave_time < 0)
+    if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)
     {
         std::cout << ", not supported yet" << std::flush << std::endl;
         return false;
     }

+    const float ave_time = (appendkv_ave_time + fwd_ave_time);
+
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
     float gb_per_sec = num_byte / 1.E6 / ave_time;
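batch_stride_block_table above is the row stride of the per-batch block table that maps logical K/V positions onto physical page blocks. The lookup itself is simple, and mirrors what the reference check below does with block_table_host (an illustrative sketch, assuming a row-major int32 table as allocated above):

#include <cstdint>

// Map a logical key/value position to its physical location in a paged
// KV cache: this batch's block-table row names the physical page, and
// the remainder is the offset inside that page.
struct PagedPos { int32_t page; int32_t offset; };

PagedPos lookup(const int32_t* block_table, int batch_stride_block_table,
                int b, int logical_pos, int page_block_size)
{
    const int32_t* row = block_table + b * batch_stride_block_table;
    return {row[logical_pos / page_block_size],                            // physical page id
            static_cast<int32_t>(logical_pos % page_block_size)};          // in-page offset
}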
@@ -773,36 +1089,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
     o_buf.FromDevice(o_host.data());
     lse_buf.FromDevice(lse_host.data());
     randval_buf.FromDevice(randval_host.data());

+    auto p_compute_element_func = [&]() {
+        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+            return ck_tile::scales{scale_p};
+        else
+            return ck_tile::identity{};
+    }();
+
+    auto oacc_element_func = [&]() {
+        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+            return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
+                                     ck_tile::scales{scale_o});
+        else
+            return ck_tile::identity{};
+    }();
+
     float p_undrop = 1.0 - p_drop;
     uint8_t p_undrop_in_uint8_t =
         uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
     float rp_undrop = 1.0 / p_undrop;

     bool pass = true;

     for(ck_tile::index_t wb = 0; wb < batch; ++wb)
     {
         const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
         const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];

         // adjust matrix index according to the mode
-        const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
+        const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
+        const ck_tile::index_t cache_b_idx =
+            (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
         const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
         const ck_tile::index_t key_offset =
             (mode == mode_enum::batch
                  ? 0
                  : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));

+        const auto v_host_ref_lengths =
+            std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
+        const auto v_host_ref_strides =
+            is_v_rowmajor
+                ? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
+                : std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
+
         ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
         ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
-        ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
+        ck_tile::HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
         ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});

         ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
@@ -813,22 +1139,138 @@ bool run(const ck_tile::ArgParser& arg_parser)
        // clang-format off
        // permute
-       if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
-       else       q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
+       if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
+       else       q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });

#if CK_TILE_FMHA_FWD_APPENDKV_API
        // optionally apply RoPE to the q_host_ref
        if(0 < rotary_dim)
        {
            decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());

            auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
                rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);

            ck_tile::reference_batched_rotary_position_embedding(
                q_host_ref,
                rotary_cos_slice,
                rotary_sin_slice,
                is_rotary_interleaved,
                q_host_ref_ro,
                /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask);

            q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
        }
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
        if(0 < page_block_size)
        {
            if(i_perm)
            {
                k_host_ref.ForEach([&](auto& self, auto i) {
                    self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
                                     i[0] / nr,
                                     i[1] % page_block_size,
                                     i[2]);
                });
            }
            else
            {
                k_host_ref.ForEach([&](auto& self, auto i) {
                    self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
                                     i[1] % page_block_size,
                                     i[0] / nr,
                                     i[2]);
                });
            }
        }
        else
#endif
        {
-           if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
-           else       k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
+           if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
+           else       k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
        }
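The block-table lookup above is the whole paged-KV addressing scheme: a logical key position `s` in batch `wb` lives in physical page `block_table_host(wb, s / page_block_size)` at in-page row `s % page_block_size`. A minimal standalone sketch of that mapping (the page-table contents are made up for illustration):

    #include <cassert>
    #include <vector>

    int main()
    {
        const int page_block_size = 128;
        // hypothetical per-batch page table: logical page -> physical page
        const std::vector<int> block_table = {7, 2, 5};

        const int s           = 300;                                // logical key position
        const int physical_pg = block_table[s / page_block_size];   // 300 / 128 = 2 -> page 5
        const int row_in_page = s % page_block_size;                // 300 % 128 = 44
        assert(physical_pg == 5 && row_in_page == 44);
        return 0;
    }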
#if CK_TILE_FMHA_FWD_APPENDKV_API
        // copy Knew to the end of K
        if(0 < seqlen_knew)
        {
            ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
            if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
            else       knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });

            // optionally apply RoPE to the knew_host_ref
            auto* real_knew_host_ref = &knew_host_ref;
            std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
            if(0 < rotary_dim)
            {
                knew_host_ref_ro.emplace(knew_host_ref.get_lengths());

                auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
                    rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);

                ck_tile::reference_batched_rotary_position_embedding(knew_host_ref,
                                                                     rotary_cos_slice,
                                                                     rotary_sin_slice,
                                                                     is_rotary_interleaved,
                                                                     knew_host_ref_ro.value());

                real_knew_host_ref = &knew_host_ref_ro.value();
            }

            (*real_knew_host_ref).ForEach([&](auto& self, auto i) {
                k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i);
            });
        }
#endif
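After this step the reference key tensor is the concatenation of cached keys and new keys along the sequence axis: rows [0, cache_seqlen_k) come from the cache and rows [cache_seqlen_k, cache_seqlen_k + seqlen_knew) from Knew. A tiny sketch of that offset rule (hypothetical sizes, plain ints standing in for head vectors):

    #include <cassert>
    #include <vector>

    int main()
    {
        const int cache_seqlen_k = 3, seqlen_knew = 2;
        std::vector<int> k_rows = {10, 11, 12};  // existing cached rows
        const std::vector<int> k_new = {20, 21}; // rows to append

        k_rows.resize(cache_seqlen_k + seqlen_knew);
        for(int i = 0; i < seqlen_knew; ++i)
            k_rows[cache_seqlen_k + i] = k_new[i]; // same "+ cache_seqlen_ks[wb]" rule as above

        assert(k_rows[3] == 20 && k_rows[4] == 21);
        return 0;
    }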
#if CK_TILE_FMHA_FWD_SPLITKV_API
        if(0 < page_block_size)
        {
            if(is_v_rowmajor)
            {
                if(i_perm)
                {
                    v_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                         i[0] / nr,
                                         i[2] % page_block_size,
                                         i[1]);
                    });
                }
                else
                {
                    v_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                         i[2] % page_block_size,
                                         i[0] / nr,
                                         i[1]);
                    });
                }
            }
            else
            {
                if(i_perm)
                {
                    v_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                         i[0] / nr,
                                         i[1],
                                         i[2] % page_block_size);
                    });
                }
                else
                {
                    v_host_ref.ForEach([&](auto& self, auto i) {
                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                         i[1],
                                         i[0] / nr,
                                         i[2] % page_block_size);
                    });
                }
            }
        }
        else
#endif
        {
            if(is_v_rowmajor)
            {
                // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
-               if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
+               if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
                // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
-               else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
+               else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
            }
            else
            {
-               if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); });
-               else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
+               if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
+               else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
            }
        }
#if CK_TILE_FMHA_FWD_APPENDKV_API
        // copy Vnew to the end of V
        if(0 < seqlen_knew)
        {
            ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
            if(is_v_rowmajor)
            {
                if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
                else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
            }
            else
            {
                if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
                else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
            }

            vnew_host_ref.ForEach([&](auto& self, auto i) {
                v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i);
            });
        }
#endif
// clang-format on
// reference
...
@@ -957,7 +1399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
            ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
                {nhead, real_seqlen_q, real_seqlen_k});
            randval_host_ref.ForEach([&](auto& self, auto idx) {
-               self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
+               self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
            });
            ck_tile::reference_batched_dropout(
                p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
...
@@ -974,8 +1416,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
        ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
        // clang-format off
        // permute
-       if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
-       else       o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); });
+       if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
+       else       o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
        // clang-format on

        auto [rtol, atol] = get_elimit<DataType>(init_method);
...
@@ -996,8 +1438,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
        if(lse)
        {
            ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
-           lse_host_result.ForEach([&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
+           lse_host_result.ForEach([&](auto& self, auto idx) { self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); });

            cur_pass = ck_tile::check_err(lse_host_result,
                                          lse_host_ref,
...
example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "bias.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include <type_traits>
template
<
typename
DataType
>
...
...
@@ -93,13 +96,86 @@ struct fmha_fwd_args
    const void* v_ptr;
    const void* bias_ptr; // bias or alibi_slope pointer
    void* rand_val_ptr;
    void* lse_ptr;
    void* o_ptr;

    const void* seqstart_q_ptr;
    const void* seqstart_k_ptr;
    const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;

    float scale_s;
    float scale_p;
    float scale_o;

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
    ck_tile::index_t stride_randval;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_bias;
    ck_tile::index_t nhead_stride_randval;
    ck_tile::index_t nhead_stride_lse;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_bias;
    ck_tile::index_t batch_stride_randval;
    ck_tile::index_t batch_stride_lse;
    ck_tile::index_t batch_stride_o;

    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;

    float p_drop;
    bool s_randval;
    std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
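For orientation, a sketch of how the per-tensor strides relate to the two host layouts the examples exercise. The formulas are inferred from the i_perm handling in fmha_fwd.cpp above, and the helper and its name are hypothetical, not part of this header: with a (batch, nhead, seqlen, hdim) layout, one sequence step skips hdim elements; with (batch, seqlen, nhead, hdim) it skips nhead * hdim.

    #include <cstdint>

    // hypothetical helper: row/head/batch strides of Q for the two layouts
    struct q_strides { std::int64_t stride, nhead_stride, batch_stride; };

    q_strides make_q_strides(bool i_perm, std::int64_t nhead, std::int64_t seqlen, std::int64_t hdim)
    {
        if(i_perm) // (batch, nhead, seqlen, hdim)
            return {hdim, seqlen * hdim, nhead * seqlen * hdim};
        else       // (batch, seqlen, nhead, hdim)
            return {nhead * hdim, hdim, seqlen * nhead * hdim};
    }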
struct fmha_fwd_splitkv_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* bias_ptr; // bias or alibi_slope pointer
    void* lse_acc_ptr;
    void* o_acc_ptr;
    void* lse_ptr;
    void* o_ptr;

    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
    const void* cache_batch_idx;

    // the real seqlen_q & seqlen_k are decided by following:
    // batch mode: seqlen_q = kargs.seqlen_q
    //             seqlen_k = kargs.seqlen_k
    // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
    // kvcache mode (use same kernel as batch mode):
    //             seqlen_q = kargs.seqlen_q
    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
    const void* seqstart_q_ptr;
    const void* seqstart_k_ptr;
    const void* seqlen_k_ptr;
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
...
@@ -109,21 +185,21 @@ struct fmha_fwd_args
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;
    ck_tile::index_t num_splits;

    float scale_s;
    float scale_p;
    float scale_o;

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
    ck_tile::index_t stride_randval;
    ck_tile::index_t stride_o_acc;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_bias;
    ck_tile::index_t nhead_stride_randval;
    ck_tile::index_t nhead_stride_lse;
    ck_tile::index_t nhead_stride_lse_acc;
    ck_tile::index_t nhead_stride_o_acc;
...
@@ -132,19 +208,62 @@ struct fmha_fwd_args
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_bias;
    ck_tile::index_t batch_stride_randval;
    ck_tile::index_t batch_stride_lse;
    ck_tile::index_t batch_stride_lse_acc;
    ck_tile::index_t batch_stride_o_acc;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t split_stride_lse_acc;
    ck_tile::index_t split_stride_o_acc;

    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;

    float p_drop;
    bool s_randval;
    std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
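The comment block in this struct is the key contract: which of seqlen_q/seqlen_k versus the seqstart arrays is honored depends on the mode. A small sketch of the group/kvcache derivation, with hypothetical values mirroring the formulas in the comment:

    #include <cassert>

    int main()
    {
        // group mode: per-batch lengths come from prefix sums
        const int seqstart_k[] = {0, 100, 250, 330};
        const int b            = 1;
        const int seqlen_k_b   = seqstart_k[b + 1] - seqstart_k[b]; // 250 - 100 = 150
        assert(seqlen_k_b == 150);
        return 0;
    }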
struct fmha_fwd_appendkv_args
{
    void* q_ptr;
    void* k_ptr;
    const void* knew_ptr;
    void* v_ptr;
    const void* vnew_ptr;

    const void* seqlen_k_ptr;

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_knew;
    ck_tile::index_t batch;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;

    const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
    const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
    ck_tile::index_t rotary_dim;
    bool has_mask;

    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr

    const void* cache_batch_idx;

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_knew;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_vnew;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_knew;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_vnew;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_knew;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_vnew;
};

template <typename FmhaKernel>
...
@@ -185,7 +304,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
                                      args.nhead_stride_randval,
                                      args.nhead_stride_lse,
                                      args.nhead_stride_o,
                                      args.batch_stride_lse,
                                      args.window_size_left,
                                      args.window_size_right,
                                      args.mask_type,
...
@@ -245,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}

template <typename Kernel>
-auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
...
@@ -256,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
            args.k_ptr,
            args.v_ptr,
            args.bias_ptr,
-           args.rand_val_ptr,
            args.lse_acc_ptr,
            args.o_acc_ptr,
            args.batch,
            args.max_seqlen_q,
            args.seqstart_q_ptr,
            args.seqstart_k_ptr,
            args.seqlen_k_ptr,
...
@@ -275,25 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
            args.stride_k,
            args.stride_v,
            args.stride_bias,
-           args.stride_randval,
            args.stride_o_acc,
            args.nhead_stride_q,
            args.nhead_stride_k,
            args.nhead_stride_v,
            args.nhead_stride_bias,
-           args.nhead_stride_randval,
            args.nhead_stride_lse_acc,
            args.nhead_stride_o_acc,
            args.batch_stride_k,
            args.batch_stride_v,
            args.batch_stride_lse_acc,
            args.batch_stride_o_acc,
            args.split_stride_lse_acc,
            args.split_stride_o_acc,
            args.window_size_left,
            args.window_size_right,
-           args.mask_type,
-           args.p_drop,
-           args.s_randval,
-           args.drop_seed_offset);
+           args.mask_type);
    }
    else
    {
        // create batch mode kernel arguments
...
@@ -301,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
            args.k_ptr,
            args.v_ptr,
            args.bias_ptr,
-           args.rand_val_ptr,
            args.lse_acc_ptr,
            args.o_acc_ptr,
            args.batch,
            args.max_seqlen_q,
            args.seqlen_q,
            args.seqlen_k,
            args.seqlen_k_ptr,
            args.hdim_q,
            args.hdim_v,
            args.nhead_q,
            args.nhead_q / args.nhead_k,
+           args.num_splits,
+           args.block_table_ptr,
+           args.batch_stride_block_table,
+           args.page_block_size,
+           args.cache_batch_idx,
            args.scale_s,
            args.scale_p,
            args.stride_q,
            args.stride_k,
            args.stride_v,
            args.stride_bias,
-           args.stride_randval,
            args.stride_o_acc,
            args.nhead_stride_q,
            args.nhead_stride_k,
            args.nhead_stride_v,
            args.nhead_stride_bias,
-           args.nhead_stride_randval,
            args.nhead_stride_lse_acc,
            args.nhead_stride_o_acc,
            args.batch_stride_q,
            args.batch_stride_k,
            args.batch_stride_v,
            args.batch_stride_bias,
-           args.batch_stride_randval,
            args.batch_stride_lse_acc,
            args.batch_stride_o_acc,
            args.split_stride_lse_acc,
            args.split_stride_o_acc,
            args.window_size_left,
            args.window_size_right,
-           args.mask_type,
-           args.p_drop,
-           args.s_randval,
-           args.drop_seed_offset);
+           args.mask_type);
        }
    }();
...
@@ -353,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
}

template <typename Kernel>
-auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
...
@@ -376,9 +486,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
            args.nhead_stride_o_acc,
            args.nhead_stride_lse,
            args.nhead_stride_o,
            args.batch_stride_lse_acc,
            args.batch_stride_o_acc,
            args.batch_stride_lse,
            args.split_stride_lse_acc,
            args.split_stride_o_acc);
    }
...
@@ -414,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
    return ck_tile::make_tuple(kargs, grids);
}

template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = Kernel::MakeKargs(args.q_ptr,
                                   args.k_ptr,
                                   args.knew_ptr,
                                   args.v_ptr,
                                   args.vnew_ptr,
                                   args.seqlen_q,
                                   args.seqlen_k_ptr,
                                   args.seqlen_knew,
                                   args.hdim_q,
                                   args.hdim_v,
                                   args.nhead_q,
                                   args.nhead_q / args.nhead_k,
                                   args.rotary_cos_ptr,
                                   args.rotary_sin_ptr,
                                   args.rotary_dim,
                                   args.has_mask,
                                   args.block_table_ptr,
                                   args.batch_stride_block_table,
                                   args.page_block_size,
                                   args.cache_batch_idx,
                                   args.stride_q,
                                   args.stride_k,
                                   args.stride_knew,
                                   args.stride_v,
                                   args.stride_vnew,
                                   args.nhead_stride_q,
                                   args.nhead_stride_k,
                                   args.nhead_stride_knew,
                                   args.nhead_stride_v,
                                   args.nhead_stride_vnew,
                                   args.batch_stride_q,
                                   args.batch_stride_k,
                                   args.batch_stride_knew,
                                   args.batch_stride_v,
                                   args.batch_stride_vnew);

    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);

    return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
          typename DataType_,
...
@@ -462,8 +615,52 @@ struct fmha_fwd_traits_
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);

template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          typename FmhaMask_,
          ck_tile::BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLse_,
          bool kDoFp8StaticQuant_,
          bool kIsPagedKV_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
    static constexpr ck_tile::index_t HDim           = HDim_;
    using DataType                                   = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode               = kIsGroupMode_;
    static constexpr ck_tile::index_t kM0            = kM0_;
    static constexpr ck_tile::index_t kN0            = kN0_;
    static constexpr ck_tile::index_t kK0            = kK0_;
    static constexpr ck_tile::index_t kN1            = kN1_;
    static constexpr ck_tile::index_t kK1            = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    static constexpr bool kIsVLayoutRowMajor         = kIsVLayoutRowMajor_;
    static constexpr auto FmhaPipelineEnum           = FmhaPipelineEnum_;
    using FmhaMask                                   = ck_tile::remove_cvref_t<FmhaMask_>;
    static constexpr auto BiasEnum                   = BiasEnum_;
    static constexpr bool kStoreLse                  = kStoreLse_;
    static constexpr bool kDoFp8StaticQuant          = kDoFp8StaticQuant_;
    static constexpr bool kPadS                      = kPadS_;
    static constexpr bool kPadSK                     = kPadSK_;
    static constexpr bool kPadD                      = kPadD_;
    static constexpr bool kPadDv                     = kPadDv_;
    static constexpr bool kIsPagedKV                 = kIsPagedKV_;
};

template <typename Traits_>
-void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
+void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
...
@@ -491,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
};

template <typename Traits_>
-void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
+void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
          typename DataType_,
          ck_tile::index_t kTileSizeS_,
          ck_tile::index_t kTileSizeSk_,
          ck_tile::index_t kTileSizeD_,
          ck_tile::index_t kTileSizeDv_,
          bool kIsVLayoutRowMajor_,
          bool kPadS_,
          bool kPadSk_,
          bool kPadD_,
          bool kPadDv_,
          ck_tile::RotaryEmbeddingEnum RotaryEnum_,
          bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
    static constexpr ck_tile::index_t HDim        = HDim_;
    using DataType                                = ck_tile::remove_cvref_t<DataType_>;
    static constexpr ck_tile::index_t kTileSizeS  = kTileSizeS_;
    static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
    static constexpr ck_tile::index_t kTileSizeD  = kTileSizeD_;
    static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
    static constexpr bool kIsVLayoutRowMajor      = kIsVLayoutRowMajor_;
    static constexpr bool kPadS                   = kPadS_;
    static constexpr bool kPadSk                  = kPadSk_;
    static constexpr bool kPadD                   = kPadD_;
    static constexpr bool kPadDv                  = kPadDv_;
    static constexpr auto RotaryEnum              = RotaryEnum_;
    static constexpr bool kIsPagedKV              = kIsPagedKV_;
};

template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);

// This is the public API, will be generated by script
struct fmha_fwd_traits
{
...
@@ -512,4 +743,32 @@ struct fmha_fwd_traits
    // TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);

-float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
+struct fmha_fwd_splitkv_traits
+{
+    int hdim_q;
+    int hdim_v;
+    std::string data_type;
+    bool is_group_mode;
+    bool is_v_rowmajor;
+    mask_enum mask_type;
+    bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
+    bool has_lse;
+    bool do_fp8_static_quant;
+    // TODO: padding check is inside this api
+};
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, fmha_fwd_splitkv_args, const ck_tile::stream_config&);
+
+struct fmha_fwd_appendkv_traits
+{
+    int hdim_q;
+    int hdim_v;
+    std::string data_type;
+    bool is_v_rowmajor;
+    rope_enum rope_type;
+};
+float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&);
example/ck_tile/01_fmha/generate.py
@@ -5,25 +5,30 @@
import argparse
from enum import IntEnum
from pathlib import Path
+import pkgutil
+import sys
from typing import List, Optional

+import codegen.ops
from codegen.cmake_config import *
-from codegen.ops import (fmha_fwd, fmha_fwd_splitkv, fmha_bwd)

class HandlerId(IntEnum):
    LIST_BLOBS  = 0
    WRITE_BLOBS = 1

-handlers = {
-    'fwd'         : (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
-    'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
-    'bwd'         : (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
-}
+# inspect all modules under 'codegen.ops' and register API handlers
+ops = []
+for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
+    full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
+    if full_module_name not in sys.modules:
+        ops.append(importer.find_spec(module_name).loader.load_module(module_name))
+unwanted_prefix = 'fmha_'
+handlers = dict(
+    [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
+      (op.list_blobs, op.write_blobs))
+     for op in ops]
+)
+assert 0 < len(handlers)

def write_blobs(output_dir: Optional[str],
                api_list: List[str],
                kernel_filter: Optional[str],
                receipt,
                mask_impl) -> None:
    if output_dir is None:
...
example/ck_tile/01_fmha/rotary.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>

// keep sync with RotaryEmbeddingEnum
enum class rope_enum
{
    none         = 0,
    interleaved  = 1,
    half_rotated = 2,
};

template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
                        ck_tile::index_t rotary_dim,
                        std::optional<unsigned> seed = std::nullopt)
{
    // return dummy tensors if we won't apply RoPE at all
    if(rotary_dim <= 0)
    {
        ck_tile::HostTensor<DataType> dummy({1, 1});
        return std::make_tuple(dummy, dummy);
    }

    std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
    std::uniform_real_distribution<float> generator(0.0f, 1.0f);

    const ck_tile::index_t num_rows = seqlen * 2;
    const ck_tile::index_t num_cols = rotary_dim / 2;

    using std::begin, std::end;

    ck_tile::HostTensor<float> angle({num_rows, num_cols});
    std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });

    ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::cos(origin_value));
    });

    ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::sin(origin_value));
    });

    return std::make_tuple(cos, sin);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
                     const ck_tile::HostTensor<DataType>& sin,
                     ck_tile::index_t seqlen_offset,
                     ck_tile::index_t seqlen)
{
    assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
    assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
    assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));

    const ck_tile::index_t num_rows = seqlen;
    const ck_tile::index_t num_cols = cos.get_length(1);

    ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
    cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });

    ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
    sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });

    return std::make_tuple(cos_pt, sin_pt);
}
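A short usage sketch of the two helpers together. The sizes are hypothetical, and `ck_tile::fp16_t` is assumed as the host data type to match the fp16 path of the example:

    // build cos/sin tables long enough for cache + new tokens, then take the
    // window that starts at the current cache length
    auto [cos, sin] = generate_rotary_cos_sin<ck_tile::fp16_t>(
        /*seqlen=*/256, /*rotary_dim=*/64, /*seed=*/42u);

    // rows [128, 128 + 16) of the tables, i.e. the positions of 16 appended tokens
    auto [cos_slice, sin_slice] = slice_rotary_cos_sin(cos, sin,
                                                       /*seqlen_offset=*/128,
                                                       /*seqlen=*/16);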
example/ck_tile/01_fmha/script/benchmark_bwd.sh
#!/bin/sh
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_bwd
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
VALID=0

for prec in "fp16" "bf16" ; do
...
example/ck_tile/01_fmha/script/benchmark_fwd.sh
#!/bin/sh
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_fwd
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
VALID=0

for prec in "fp16" "bf16" ; do
...
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
#!/bin/sh
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_bwd
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
KNAME=1

export CK_WARMUP=0
...
@@ -11,18 +10,19 @@ COMMON_ARGS='-v=1'
set -x

for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
-for hdim in 32 64 128 ; do
+for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do
-for bias in "n" "e" "a" ; do
-for dbias in 0 1 ; do
+for bias in "n" "a" ; do
+for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
+for deterministic in 0 ; do

-$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS

done
done
...
@@ -31,4 +31,5 @@ done
done
done
done
+done
done
set +x
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
-#!/bin/sh
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_fwd
+#!/bin/bash
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1

export CK_WARMUP=0
...
@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'

# mode=0
# export HIP_VISIBLE_DEVICES=4
-set -x
-for prec in "fp16" "bf16" ; do
-for mode in 1 0 ; do
-for perm in 0 1 ; do
-for vlayout in "r" "c" ; do
-for hdim in 32 64 128 256 ; do
-for lse in 0 1 ; do
-for bias in "n" "e" "a" ; do
-for p_drop in 0.0 0.2 ; do
-# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-done
-done
-done
-done
-done
-done
-done
-done
TEST_SPLITKV=0
TEST_APPENDKV=0

# options:
#   -s: run splitkv tests
#   -a: run appendkv tests
while getopts ":sa" opt; do
    case "${opt}" in
        s ) TEST_SPLITKV=1 ;;
        a ) TEST_APPENDKV=1 ;;
        * ) ;;
    esac
done
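# usage sketch (assumed to be run from the build tree, like the other scripts):
#   ./smoke_test_fwd.sh         # fp16/bf16 + fp8 smoke tests only
#   ./smoke_test_fwd.sh -s      # also sweep num_splits / page_block_size / cache_batch_idx
#   ./smoke_test_fwd.sh -a      # also run the fp16 appendkv tests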
run_fp16_bf16_tests() {
    local NUM_SPLITS=(1)
    local PAGE_BLOCK_SIZE=(0)
    local CACHE_BATCH_IDX=(0)

    if [ $TEST_SPLITKV -eq 1 ] ; then
        NUM_SPLITS+=(2 3)
        PAGE_BLOCK_SIZE+=(128)
        CACHE_BATCH_IDX+=(1)
    fi

    for prec in "fp16" "bf16" ; do
    for mode in 1 0 ; do
    for perm in 0 1 ; do
    for vlayout in "r" "c" ; do
    for hdim in 32 64 128 256 ; do
    for lse in 0 1 ; do
    for bias in "n" "e" "a" ; do
    for p_drop in 0.0 0.2 ; do
    for num_splits in "${NUM_SPLITS[@]}" ; do
    for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
    for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
    # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    done ; done ; done ; done ; done
    done ; done ; done ; done ; done
    done ;
}
run_fp8_tests() {
    for perm in 0 1 ; do
    for bias in "n" "e" "a" ; do
    for b in 1 2 ; do
    for hdim in 64 128 256 ; do
    $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
    done ; done ; done ; done
}
run_fp16_appendkv_tests() {
    for s in $(seq 63 1 65) ; do
    for s_k in 65 129 ; do
    for s_knew in 0 64 $s_k ; do
    for hdim in 32 64 128 256 ; do
    for ri in 0 1 ; do
    for rdim in 0 16 32 $hdim ; do
    for page_block_size in 0 128 ; do
    for cache_batch_idx in 0 1 ; do
    $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
    done ; done ; done ; done ; done
    done ; done ; done
}
set -x

run_fp16_bf16_tests
run_fp8_tests

if [ $TEST_APPENDKV -eq 1 ] ; then
    run_fp16_appendkv_tests
fi

-for perm in 0 1 ; do
-for bias in "n" "e" "a" ; do
-for b in 1 2 ; do
-for hdim in 64 128 256 ; do
-$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
-done
-done
-done
-done
set +x
\ No newline at end of file
example/ck_tile/01_fmha/utils.hpp
@@ -3,15 +3,17 @@
#pragma once

#include <algorithm>
#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <optional>
#include <ostream>
+#include <sstream>
+#include <string>
#include <tuple>
#include <utility>
#include <vector>
-#include <functional>
-#include <string>

#include "ck_tile/core/container/span.hpp"
...
@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
std::vector<int32_t> generate_seqlens(mode_enum mode,
                                      unsigned count,
                                      int32_t seqlen_avg,
+                                     int32_t seqlen_min = -1, // if not negative, clamp min
                                      int32_t seqlen_max = -1, // if not negative, clamp max
                                      std::optional<unsigned> seed = std::nullopt)
{
    assert(0 < count);

-   std::vector<int32_t> seqlens(
-       count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);
+   seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
+   seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
+   assert(seqlen_min <= seqlen_max);
+
+   std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));

    if(mode == mode_enum::group && 1 < count)
    {
...
@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
        for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
        {
            const size_type to_decrease = next_idx();
-           // make sure each element of seqlens is always greater than 0
-           if(seqlens[to_decrease] == 1)
+           // make sure each element of seqlens is in range [seqlen_min, seqlen_max]
+           if(seqlens[to_decrease] == seqlen_min)
            {
                continue;
            }

            const size_type to_increase = (to_decrease + next_step()) % count;

-           if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
+           if(seqlens[to_increase] >= seqlen_max)
            {
                continue;
            }
...
@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std::vector<int32_t> generate_seqstarts(mode_enum mode,
                                        unsigned count,
                                        int32_t seqlen_avg,
+                                       int32_t seqlen_min = -1,
                                        int32_t seqlen_max = -1,
                                        std::optional<unsigned> seed = std::nullopt)
{
-   return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
+   return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
}

// return random integer generated uniformly in range [low, high]
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
    -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
    std::uniform_int_distribution<Int> dist(low, high);
    return dist(engine);
}

// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator>
auto randints(ForwardIterator first,
              ForwardIterator last,
              Int low,
              Int high,
              std::optional<unsigned> seed = std::nullopt)
    -> std::enable_if_t<std::is_integral_v<Int>>
{
    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
    std::uniform_int_distribution<Int> dist(low, high);
    std::generate(first, last, [&] { return dist(engine); });
}
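A brief usage sketch of the two helpers above; passing a seed makes both reproducible:

    #include <vector>

    void example_random_helpers()
    {
        // one value in [1, 6], seeded for reproducibility
        const int roll = randint(1, 6, 123u);

        // fill a vector with values in [0, 255]
        std::vector<int> bytes(16);
        randints(bytes.begin(), bytes.end(), 0, 255, 123u);
        (void)roll;
    }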
/*
...
@@ -112,6 +144,8 @@ decode_seqlen(mode_enum mode,
              std::string q_val,
              std::string k_val,
              std::string k_pad_val,
+             ck_tile::index_t seqlen_k_min = 0,
+             bool use_kvcache              = false,
              std::optional<unsigned> seed = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
...
@@ -119,9 +153,36 @@ decode_seqlen(mode_enum mode,
    {
        ck_tile::index_t q = _S2I_(q_val);
        ck_tile::index_t k = _S2I_(k_val);

        auto s_q = std::vector<ck_tile::index_t>(batch, q);
-       auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
+       auto s_k = [&] {
+           const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
+           std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
+           if(1 < batch && use_kvcache)
+           {
+               // to keep the original s_k value, we always use seqlen_k_max in first batch
+               randints(std::next(seqlen_ks.begin()),
+                        seqlen_ks.end(),
+                        seqlen_k_min,
+                        seqlen_k_max,
+                        seed);
+               return seqlen_ks;
+           }
+           return seqlen_ks;
+       }();
        auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding

+       // s_k should be greater than or equal to seqlen_k_min if provided
+       if(s_k.back() < seqlen_k_min)
+       {
+           std::ostringstream msg;
+           msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
+               << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
+           throw std::runtime_error(msg.str());
+       }

        return std::make_tuple(s_q, s_k, s_kpad);
    }
    else
...
@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
            s_q.push_back(q);
            s_k.push_back(k < 0 ? q : k);
            s_kpad.push_back(kp);

+           // s_k should be greater than or equal to seqlen_k_min
+           if(s_k.back() < seqlen_k_min)
+           {
+               std::ostringstream msg;
+               msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
+                   << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
+               throw std::runtime_error(msg.str());
+           }

            idx++;
            if(found_q == std::string::npos || idx >= batch)
            {
...
@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
        }

        if(idx < batch)
        {
-           auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
-           auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
+           auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
+           auto rem_k =
+               generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);

            s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
            s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
...
@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
        r = std::atoi(v);
    return r;
}

template <typename RandomAccessIterator, typename Int>
std::enable_if_t<std::is_integral_v<Int>>
iota_shuffle(RandomAccessIterator first,
             RandomAccessIterator last,
             Int value,
             std::optional<unsigned> seed = std::nullopt)
{
    std::iota(first, last, value);

    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
    std::shuffle(first, last, engine);
}
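A quick sketch of what iota_shuffle produces, namely a seeded random permutation of consecutive integers:

    #include <vector>

    void example_iota_shuffle()
    {
        std::vector<int> perm(8);
        // fills perm with 0..7, then shuffles it with a fixed seed,
        // e.g. yielding some permutation of {0,1,2,3,4,5,6,7}
        iota_shuffle(perm.begin(), perm.end(), 0, 7u);
    }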
include/ck/ck.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0

-// set stochastic rounding as default for f8 conversions
-#define CK_USE_SR_F8_CONVERSION 1
+// set rounding to nearest even as default for f8 conversions
+#define CK_USE_SR_F8_CONVERSION 0

// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
include/ck/filesystem.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#ifndef GUARD_CK_FILESYSTEM_HPP_
#define GUARD_CK_FILESYSTEM_HPP_

#include <string>
#include <string_view>

// clang-format off
#if defined(CPPCHECK)
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define CK_HAS_FILESYSTEM 1
#else
#define CK_HAS_FILESYSTEM 0
#endif
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM_TS 0
#endif
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
// clang-format on

#if CK_HAS_FILESYSTEM
#include <filesystem>
#elif CK_HAS_FILESYSTEM_TS
#include <experimental/filesystem>
#else
#error "No filesystem include available"
#endif

namespace CK {

#if CK_HAS_FILESYSTEM
namespace fs = ::std::filesystem;
#elif CK_HAS_FILESYSTEM_TS
namespace fs = ::std::experimental::filesystem;
#endif

} // namespace CK

inline std::string operator+(const std::string_view s, const CK::fs::path& path)
{
    return path.string().insert(0, s);
}

inline std::string operator+(const CK::fs::path& path, const std::string_view s)
{
    return path.string().append(s);
}

#define FS_ENUM_PERMS_ALL fs::perms::all

#if CK_HAS_FILESYSTEM_TS
#ifdef __linux__
#include <linux/limits.h>
namespace CK {
inline fs::path weakly_canonical(const fs::path& path)
{
    std::string result(PATH_MAX, '\0');
    std::string p{path.is_relative() ? (fs::current_path() / path).string() : path.string()};
    char* retval = realpath(p.c_str(), &result[0]);
    return (retval == nullptr) ? path : fs::path{result};
}
} // namespace CK
#else
#error "Not implemented!"
#endif
#else
namespace CK {
inline fs::path weakly_canonical(const fs::path& path) { return fs::weakly_canonical(path); }
} // namespace CK
#endif

namespace CK {

#ifdef _WIN32
constexpr std::string_view executable_postfix{".exe"};
constexpr std::string_view library_prefix{""};
constexpr std::string_view dynamic_library_postfix{".dll"};
constexpr std::string_view static_library_postfix{".lib"};
constexpr std::string_view object_file_postfix{".obj"};
#else
constexpr std::string_view executable_postfix{""};
constexpr std::string_view library_prefix{"lib"};
constexpr std::string_view dynamic_library_postfix{".so"};
constexpr std::string_view static_library_postfix{".a"};
constexpr std::string_view object_file_postfix{".o"};
#endif

inline fs::path make_executable_name(const fs::path& path)
{
    return path.parent_path() / (path.filename() + executable_postfix);
}

inline fs::path make_dynamic_library_name(const fs::path& path)
{
    return path.parent_path() / (library_prefix + path.filename() + dynamic_library_postfix);
}

inline fs::path make_object_file_name(const fs::path& path)
{
    return path.parent_path() / (path.filename() + object_file_postfix);
}

inline fs::path make_static_library_name(const fs::path& path)
{
    return path.parent_path() / (library_prefix + path.filename() + static_library_postfix);
}
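For illustration, what the naming helpers evaluate to on Linux, given the platform constants above (a sketch; the paths are made up):

    #include <cassert>

    void example_names()
    {
        using namespace CK;
        // "lib" prefix and ".so" suffix are applied around the filename only
        assert(make_dynamic_library_name("/opt/ck/utility").string() == "/opt/ck/libutility.so");
        // executables get no postfix on Linux
        assert(make_executable_name("/opt/ck/tool").string() == "/opt/ck/tool");
    }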
struct FsPathHash
{
    std::size_t operator()(const fs::path& path) const { return fs::hash_value(path); }
};

} // namespace CK

#endif // GUARD_CK_FILESYSTEM_HPP_
include/ck/host_utility/device_prop.hpp
@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}

+inline bool is_bf16_atomic_supported()
+{
+    return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+           ck::get_device_name() == "gfx942";
+}

inline bool is_gfx101_supported()
{
    return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// GEMM:
//   input : A[M, K], B[K, N],
//   input : D0[M, N], D1[M, N], ...
//   output : E[M, N]
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
// Assume:
//   D0, D1, ... and E have the same layout
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitK : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        ck::index_t StrideE,
                        ck::index_t KBatch,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
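In math form, the contract this interface describes (the split-K factor KBatch only changes how the K reduction is scheduled across workgroups, not the result):

\[
C_{mn} = \sum_{k=0}^{K-1} a_{\mathrm{op}}(A_{mk})\, b_{\mathrm{op}}(B_{kn}),
\qquad
E_{mn} = \mathrm{cde}_{\mathrm{op}}\bigl(C_{mn},\, D^{(0)}_{mn},\, D^{(1)}_{mn},\,\ldots\bigr).
\]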
} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
@@ -69,7 +69,7 @@ template <typename ALayout,
          typename ComputeTypeB = ComputeTypeA,
          typename LDSTypeA     = ComputeTypeA,
          typename LDSTypeB     = ComputeTypeB>
-struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
+struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout,
                                                                            BLayout,
                                                                            DsLayout,
                                                                            CLayout,
...
@@ -192,15 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                    // rotating mem
                    rotating_mem.Next();
                    // clear c mem
-                   if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                   {
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
-                   }
                };

                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...
@@ -234,9 +230,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
            {
+               if(arg.KBatch > 1)
+               {
+                   const auto kernel =
+                       kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                           true,
+                                                           InMemoryDataOperationEnum::AtomicAdd,
+                                                           minimum_occupancy>;
+                   Run(kernel);
+               }
+               else
+               {
-                   const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                   const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                                            true,
                                                                            InMemoryDataOperationEnum::Set,
                                                                            minimum_occupancy>;
...
@@ -246,11 +252,124 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
+                    if(arg.KBatch > 1)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
+                        {
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                    true,
+                                                                    InMemoryDataOperationEnum::AtomicAdd,
+                                                                    minimum_occupancy,
+                                                                    TailNumber::One>;
+                            Run(kernel);
+                        }
+                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
+                                TailNumber::Full)
+                        {
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                    true,
+                                                                    InMemoryDataOperationEnum::AtomicAdd,
+                                                                    minimum_occupancy,
+                                                                    TailNumber::Full>;
+                            Run(kernel);
+                        }
+
+                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
+                        {
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
+                            {
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Two>;
+                                Run(kernel);
+                            }
+                        }
+
+                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
+                        {
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
+                            {
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Three>;
+                                Run(kernel);
+                            }
+                        }
+
+                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
+                        {
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
+                            {
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Four>;
+                                Run(kernel);
+                            }
+                        }
+
+                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
+                        {
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
+                            {
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Five>;
+                                Run(kernel);
+                            }
+                        }
+
+                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
+                        {
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
+                            {
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Six>;
+                                Run(kernel);
+                            }
+                        }
+
+                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
+                        {
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
+                            {
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Seven>;
+                                Run(kernel);
+                            }
+                        }
+                    }
+                    else
+                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel =
-                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
...
...
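The TailNumber ladder above exists because the v2 pipeline prefetches several K-loop stages, and the remaining iterations need a matching epilogue instantiation chosen at launch time. A sketch of the selection rule (assumed semantics, reduced to plain integers):

    // Map the number of K-block loop iterations onto a tail-number bucket:
    // an exact multiple of the prefetch depth selects "Full", otherwise the
    // remainder (1..PrefetchStages-1) selects the matching epilogue.
    constexpr int tail_number(int k_loops, int prefetch_stages)
    {
        const int r = k_loops % prefetch_stages;
        return r == 0 ? prefetch_stages : r; // prefetch_stages plays the role of "Full"
    }

    static_assert(tail_number(16, 8) == 8); // TailNumber::Full
    static_assert(tail_number(19, 8) == 3); // TailNumber::Three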
@@ -261,7 +380,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                                TailNumber::Full)
                        {
                            const auto kernel =
-                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
...
...
@@ -273,8 +392,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                            {
-                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                                        true,
                                                                        InMemoryDataOperationEnum::Set,
                                                                        minimum_occupancy,
...
...
@@ -288,8 +407,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
                            {
-                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                                        true,
                                                                        InMemoryDataOperationEnum::Set,
                                                                        minimum_occupancy,
...
...
@@ -303,8 +422,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
                            {
-                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                                        true,
                                                                        InMemoryDataOperationEnum::Set,
                                                                        minimum_occupancy,
...
...
@@ -318,8 +437,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
                            {
-                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                                        true,
                                                                        InMemoryDataOperationEnum::Set,
                                                                        minimum_occupancy,
...
...
@@ -332,8 +451,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                            {
-                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                                        true,
                                                                        InMemoryDataOperationEnum::Set,
                                                                        minimum_occupancy,
...
...
@@ -347,8 +466,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
                            {
-                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                const auto kernel =
+                                    kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                                        true,
                                                                        InMemoryDataOperationEnum::Set,
                                                                        minimum_occupancy,
...
...
@@ -361,12 +480,35 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                // Tail number could be Odd or Even
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
+                    if(arg.KBatch > 1)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                        {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
+                                                                         true,
+                                                                         InMemoryDataOperationEnum::AtomicAdd,
+                                                                         minimum_occupancy,
+                                                                         TailNumber::Odd>;
+                            Run(kernel);
+                        }
+                        else
+                        {
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
+                                                                         true,
+                                                                         InMemoryDataOperationEnum::AtomicAdd,
+                                                                         minimum_occupancy,
+                                                                         TailNumber::Even>;
+                            Run(kernel);
+                        }
+                    }
+                    else
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                        {
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
+                                                                         true,
+                                                                         InMemoryDataOperationEnum::Set,
+                                                                         minimum_occupancy,
...
...
@@ -375,8 +517,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        }
                        else
                        {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
                                                                         true,
                                                                         InMemoryDataOperationEnum::Set,
                                                                         minimum_occupancy,
...
...
@@ -387,11 +529,35 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                }
                else
                {
+                    if(arg.KBatch > 1)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                        {
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                    true,
+                                                                    InMemoryDataOperationEnum::AtomicAdd,
+                                                                    minimum_occupancy,
+                                                                    TailNumber::Odd>;
+                            Run(kernel);
+                        }
+                        else
+                        {
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                    true,
+                                                                    InMemoryDataOperationEnum::AtomicAdd,
+                                                                    minimum_occupancy,
+                                                                    TailNumber::Even>;
+                            Run(kernel);
+                        }
+                    }
+                    else
+                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel =
-                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
...
...
@@ -401,7 +567,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        else
                        {
                            const auto kernel =
-                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                                kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
...
...
@@ -416,9 +582,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
+                    if(arg.KBatch > 1)
+                    {
+                        const auto kernel =
+                            kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                                false,
+                                                                InMemoryDataOperationEnum::AtomicAdd,
+                                                                minimum_occupancy>;
+                        Run(kernel);
+                    }
+                    else
                    {
                        const auto kernel =
-                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                            kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                                        false,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy>;
...
...
@@ -451,6 +627,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
            return false;
        }

+        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
+           arg.KBatch > 1)
+        {
+            return false;
+        }
+
        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
...
...
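The new IsSupportedArgument clause rejects split-K with a bf16 C tensor on devices that lack bf16 global atomics. The intent, as a standalone predicate (names assumed, not the CK API):

    // Split-K accumulates C with atomicAdd; if C is bf16 and the GPU cannot
    // do bf16 atomics, the configuration must be filtered out up front.
    constexpr bool splitk_config_supported(bool has_bf16_atomics, bool c_is_bf16, int k_batch)
    {
        return !(k_batch > 1 && c_is_bf16 && !has_bf16_atomics);
    }

    static_assert(!splitk_config_supported(false, true, 2)); // rejected
    static_assert(splitk_config_supported(false, true, 1));  // single pass is fine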
@@ -479,6 +660,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                             index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideC,
+                             index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
...
...
@@ -494,7 +676,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        StrideB,
                        StrideDs,
                        StrideC,
-                        1,
+                        KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
...
...
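With the hard-coded 1 replaced by the new parameter, callers now choose the split-K factor explicitly; KBatch = 1 reproduces the previous single-pass behaviour. A hypothetical convenience wrapper (not part of CK) showing where the factor sits in the ordering above (..., StrideC, KBatch, elementwise ops):

    #include <utility>

    // Forwards the leading arguments unchanged and injects the split-K factor
    // where the literal 1 used to be.
    template <typename Gemm, typename AOp, typename BOp, typename COp, typename... Leading>
    auto make_argument_with_splitk(Gemm& gemm, int k_batch, AOp a_op, BOp b_op, COp c_op,
                                   Leading&&... leading)
    {
        return gemm.MakeArgument(std::forward<Leading>(leading)..., k_batch, a_op, b_op, c_op);
    }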
@@ -514,6 +696,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                             index_t StrideB,
                             std::array<ck::index_t, NumDTensor> StrideDs,
                             index_t StrideC,
+                             index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op) override
...
...
@@ -529,7 +712,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        StrideB,
                        StrideDs,
                        StrideC,
-                        1,
+                        KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @ 72c9f129
...
...
@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                    // rotating mem
                    rotating_mem.Next();
                    // clear c mem
-                    if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                    {
-                        if(arg_.KBatch > 1)
-                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
+                    hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                     0,
                                                     arg_.M * arg_.N * sizeof(CDataType),
                                                     stream_config.stream_id_));
-                    }
                };

                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...
...
@@ -189,15 +185,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                    arg_);
            }
            else
            {
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                {
                if(arg.KBatch > 1)
                    hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                     0,
                                                     arg.M * arg.N * sizeof(CDataType),
                                                     stream_config.stream_id_));
-                }

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
...
...
@@ -214,8 +207,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(arg.KBatch > 1)
                    {
-                        if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                        {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -224,7 +215,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                        minimum_occupancy>;
                        Run(kernel);
-                        }
                    }
                    else
                    {
                        const auto kernel =
...
...
@@ -239,13 +229,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
                    if(arg.KBatch > 1)
                    {
-                        if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                        {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
...
...
@@ -255,8 +243,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                                TailNumber::Full)
                        {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
...
...
@@ -266,8 +254,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
-                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
-                               TailNumber::Two)
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -326,8 +313,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                        {
-                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
-                               TailNumber::Six)
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -354,7 +340,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                            }
                        }
                    }
-                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
...
...
@@ -472,8 +457,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
                    if(arg.KBatch > 1)
                    {
-                        if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                        {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
...
...
@@ -496,7 +479,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                            Run(kernel);
                        }
-                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...
...
@@ -524,13 +506,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                else
                {
                    if(arg.KBatch > 1)
                    {
-                        if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                        {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
...
...
@@ -539,8 +519,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        }
                        else
                        {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
...
...
@@ -548,7 +528,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                            Run(kernel);
                        }
-                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...
...
@@ -579,10 +558,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    if(arg.KBatch > 1)
                    {
-                        if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                        {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -591,7 +567,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                        minimum_occupancy>;
                        Run(kernel);
-                        }
                    }
                    else
                    {
                        const auto kernel =
...
...
@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            return false;
        }

+        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
+           arg.KBatch > 1)
+        {
+            return false;
+        }
+
        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @ 72c9f129
...
...
@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            return false;

        if constexpr(!((NDimSpatial == 1 &&
-                        (is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
-                         is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
+                        (is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
+                         is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
                       (NDimSpatial == 2 &&
-                        (is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                         is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
+                        (is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                         is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
                       (NDimSpatial == 3 &&
-                        (is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                         is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
+                        (is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                         is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
        {
            return false;
        }
...
...
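The swap from *K/*C to *C/*K predicates re-labels which tensor is channel-last versus K-last in the bwd-weight problem. A plausible shape of such a compile-time predicate (a sketch with illustrative tag types; the CK helpers live in ck::tensor_layout::convolution and may be implemented differently):

    #include <type_traits>

    struct NHWGC {}; // illustrative layout tags
    struct GKYXC {};
    struct NHWGK {};

    template <typename InLay, typename WeiLay, typename OutLay>
    constexpr bool is_NHWGC_GKYXC_NHWGK()
    {
        return std::is_same_v<InLay, NHWGC> && std::is_same_v<WeiLay, GKYXC> &&
               std::is_same_v<OutLay, NHWGK>;
    }

    static_assert(is_NHWGC_GKYXC_NHWGK<NHWGC, GKYXC, NHWGK>());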
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
View file @ 72c9f129
...
...
@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
        }

        if constexpr(NDimSpatial == 1)
        {
-            if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
+            if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 2)
        {
-            if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                           is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+            if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                           is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 3)
        {
-            if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                           is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+            if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                           is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @ 72c9f129
...
...
@@ -22,6 +22,7 @@
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          index_t NumGroupsToMerge                    = 1,
          typename ComputeTypeA                       = InDataType,
-          typename ComputeTypeB                       = ComputeTypeA>
+          typename ComputeTypeB                       = ComputeTypeA,
+          index_t TransposeTransferSrcScalarPerVector = 1,
+          index_t TransposeTransferDstScalarPerVector = 1>
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
    : public DeviceGroupedConvBwdWeight<NDimSpatial,
                                        InLayout,
...
...
@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
    using BDataType = InDataType;
    using EDataType = WeiDataType;

+    // If NGCHW then ADataType must be equal to BDataType
+    static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                    is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
+                  is_same_v<ADataType, BDataType>);
+
    using AElementwiseOperation   = OutElementwiseOperation;
    using BElementwiseOperation   = InElementwiseOperation;
    using CDEElementwiseOperation = WeiElementwiseOperation;
...
...
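The static_assert encodes "NGCHW-style layouts imply one element type for A and B": the two-stage path below funnels both tensors through the same transpose kernel and shared workspace arithmetic. The same implication, reduced to a standalone sketch with stand-in values:

    #include <type_traits>

    constexpr bool layout_needs_transpose = true; // stand-in for is_NGCHW_GKYXC_NGKHW<...>()
    using AData = float;                          // stand-in for ADataType
    using BData = float;                          // stand-in for BDataType

    static_assert(!layout_needs_transpose || std::is_same_v<AData, BData>,
                  "NGCHW two-stage path assumes A and B share an element type");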
@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                batch)[I2];
    }

+    static constexpr index_t ClusterLengthMPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
+    static constexpr index_t ClusterLengthNPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
+    static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Hi = g_n_c_wis_lengths[3];
+        const index_t& Wi = g_n_c_wis_lengths[4];
+
+        const index_t& GStride  = g_n_c_wis_strides[0];
+        const index_t& NStride  = g_n_c_wis_strides[1];
+        const index_t& CStride  = g_n_c_wis_strides[2];
+        const index_t& HiStride = g_n_c_wis_strides[3];
+        const index_t& WiStride = g_n_c_wis_strides[4];
+
+        const auto desc =
+            make_naive_tensor_descriptor(make_tuple(N, G, C, Hi, Wi),
+                                         make_tuple(NStride, GStride, CStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
+    static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                                        std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Hi = g_n_c_wis_lengths[3];
+        const index_t& Wi = g_n_c_wis_lengths[4];
+
+        const index_t& NStride = g_n_c_wis_strides[1];
+        const index_t HiStride = Wi * G * C;
+        const index_t WiStride = G * C;
+        const index_t GStride  = C;
+        const index_t CStride  = 1;
+
+        const auto desc =
+            make_naive_tensor_descriptor(make_tuple(N, G, C, Hi, Wi),
+                                         make_tuple(NStride, GStride, CStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
+    static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Di = g_n_c_wis_lengths[3];
+        const index_t& Hi = g_n_c_wis_lengths[4];
+        const index_t& Wi = g_n_c_wis_lengths[5];
+
+        const index_t& GStride  = g_n_c_wis_strides[0];
+        const index_t& NStride  = g_n_c_wis_strides[1];
+        const index_t& CStride  = g_n_c_wis_strides[2];
+        const index_t& DiStride = g_n_c_wis_strides[3];
+        const index_t& HiStride = g_n_c_wis_strides[4];
+        const index_t& WiStride = g_n_c_wis_strides[5];
+
+        const auto desc = make_naive_tensor_descriptor(
+            make_tuple(N, G, C, Di, Hi, Wi),
+            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Di, Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
+    static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                                        std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Di = g_n_c_wis_lengths[3];
+        const index_t& Hi = g_n_c_wis_lengths[4];
+        const index_t& Wi = g_n_c_wis_lengths[5];
+
+        const index_t& NStride = g_n_c_wis_strides[1];
+        const index_t DiStride = Hi * Wi * G * C;
+        const index_t HiStride = Wi * G * C;
+        const index_t WiStride = G * C;
+        const index_t GStride  = C;
+        const index_t CStride  = 1;
+
+        const auto desc = make_naive_tensor_descriptor(
+            make_tuple(N, G, C, Di, Hi, Wi),
+            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Di, Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    using InputTransposeDescType =
+        remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
+    using OutputTransposeDescType =
+        remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
+
    using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
...
...
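MakeOutputTransposeDesc hard-codes the strides of a packed NHWGC destination while MakeInputTransposeDesc keeps the caller's NGCHW strides, so an elementwise copy between the two descriptors performs the layout transpose. The destination stride derivation, isolated (hypothetical helper, same arithmetic as above):

    // Strides of a packed NHWGC tensor, indexed as (N, G, C, Hi, Wi) to match
    // the descriptor order used by the functions above.
    struct Strides2d { long n, g, c, hi, wi; };

    Strides2d packed_nhwgc_strides(long G, long C, long Hi, long Wi)
    {
        return {Hi * Wi * G * C, // N: one full image apart
                C,               // G
                1,               // C: innermost (channel-last)
                Wi * G * C,      // Hi
                G * C};          // Wi
    }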
@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                                        ComputeTypeA,
                                                        ComputeTypeB>;

-    static constexpr index_t ClusterLengthMPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
-    static constexpr index_t ClusterLengthNPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
-
    using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

-    using GridwiseElementwise =
+    using GridwiseElementwiseCast =
        GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
                            Tuple<CElementwiseGridDesc_M_N>,
                            Tuple<const AccDataType*>,
...
...
@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                            I1,
                            I1>;

+    using GridwiseElementwiseTranspose =
+        GridwiseElementwise<Tuple<InputTransposeDescType>,
+                            Tuple<OutputTransposeDescType>,
+                            Tuple<const ADataType*>,
+                            Tuple<ADataType*>,
+                            Block2TileMapElementwise,
+                            element_wise::PassThrough,
+                            BlockSize,
+                            MPerBlock,
+                            NPerBlock,
+                            MPerBlock / ClusterLengthMPerBlock,
+                            NPerBlock / ClusterLengthNPerBlock,
+                            Sequence<1, 0>,
+                            Sequence<TransposeTransferSrcScalarPerVector>,
+                            Sequence<TransposeTransferDstScalarPerVector>,
+                            I1,
+                            I0>;
+
    // Argument
    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...
...
@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                     end(a_g_n_k_wos_lengths),
                     begin(output_spatial_lengths_));

+            std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
+                b_g_n_c_wis_strides;
+            std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
+                a_g_n_k_wos_strides;
+            // NGKHW - transpose needed
+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
+                b_g_n_c_wis_strides_transposed[I2] = I1;
+                a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
+                a_g_n_k_wos_strides_transposed[I2] = I1;
+
+                if constexpr(NDimSpatial == 2)
+                {
+                    b_g_n_c_wis_strides_transposed[I3] =
+                        input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
+                    b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
+                    a_g_n_k_wos_strides_transposed[I3] =
+                        output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
+                    a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
+                }
+                else if constexpr(NDimSpatial == 3)
+                {
+                    b_g_n_c_wis_strides_transposed[I3] =
+                        input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
+                    b_g_n_c_wis_strides_transposed[I4] =
+                        input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
+                    b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
+                    a_g_n_k_wos_strides_transposed[I3] =
+                        output_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
+                    a_g_n_k_wos_strides_transposed[I4] =
+                        input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
+                    a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
+                }
+            }
+
            const auto descs =
                conv_to_gemm_transformer_v2
                    .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
...
...
@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                        input_spatial_lengths_,
                        filter_spatial_lengths_,
                        output_spatial_lengths_,
-                        b_g_n_c_wis_strides,
+                        b_g_n_c_wis_strides_transposed,
                        e_g_k_c_xs_strides,
-                        a_g_n_k_wos_strides,
+                        a_g_n_k_wos_strides_transposed,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
...
...
@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);

            // A/B/C Batch Stride
-            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
+            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
+            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
            compute_ptr_offset_of_batch_.BatchStrideC_ =
                Conv_K_ * Conv_C_ *
                std::accumulate(begin(filter_spatial_lengths_),
...
...
@@ -553,13 +750,58 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                ce_grid_desc_m_n_,
                GridwiseGemm::CalculateMBlock(GemmM),
                GridwiseGemm::CalculateNBlock(GemmN));

+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                a_in_transpose_desc_ =
+                    MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                a_out_transpose_desc_ =
+                    MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                b_in_transpose_desc_ =
+                    MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+                b_out_transpose_desc_ =
+                    MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+
+                elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
+                    a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
+                elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
+                    b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
+            }
        }

-        std::size_t GetWorkspaceSizeBytes() const
+        std::size_t GetWorkspaceATensorSizeBytes() const
+        {
+            return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
+        }
+
+        std::size_t GetWorkspaceBTensorSizeBytes() const
+        {
+            return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
+        }
+
+        std::size_t GetWorkspaceETensorSizeBytes() const
        {
            return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
        }

+        std::size_t GetWorkspaceSizeBytes() const
+        {
+            // Transpose require workspace for A and B
+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
+                       GetWorkspaceETensorSizeBytes();
+            }
+            else
+            {
+                return GetWorkspaceETensorSizeBytes();
+            }
+        }
+
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        EDataType* p_e_grid_;
...
...
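Splitting GetWorkspaceSizeBytes into per-tensor queries fixes an implicit layout: the E accumulator comes first, then transposed A, then transposed B, and the pointer offsets computed in RunGemmV3 below follow the same order. A sketch of that partitioning:

    #include <cstddef>

    struct WorkspacePartition
    {
        std::size_t e_offset, a_offset, b_offset, total_bytes;
    };

    // Byte offsets of the three regions packed back-to-back: [ E | A | B ].
    WorkspacePartition partition_workspace(std::size_t e_bytes, std::size_t a_bytes,
                                           std::size_t b_bytes)
    {
        return {0, e_bytes, e_bytes + a_bytes, e_bytes + a_bytes + b_bytes};
    }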
@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;

        Block2TileMapElementwise elementwise_block_2_ctile_map_;
+        Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
+            elementwise_block_2_ctile_map_transpose_b_;
+
+        InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
+        OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;

        // for computing batch offset
        ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
...
...
@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
+            const ADataType* p_a_grid = arg.p_a_grid_;
+            const BDataType* p_b_grid = arg.p_b_grid_;
+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
+                           arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
+                p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
+                           (arg.GetWorkspaceETensorSizeBytes() +
+                            arg.GetWorkspaceATensorSizeBytes()) /
+                               sizeof(BDataType);
+            }
            // nullptr for output, will be set after workspace set
-            typename GridwiseGemm::Argument gemm_arg{
-                arg.p_a_grid_, arg.p_b_grid_, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
+            typename GridwiseGemm::Argument gemm_arg{
+                p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
...
...
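The pointer arithmetic above divides byte sizes by sizeof(BDataType) to get element offsets, which is only valid because the earlier static_assert pins ADataType and BDataType to the same type on the NGCHW path. The same computation in raw bytes sidesteps that coupling (illustrative helper):

    #include <cstddef>

    // Re-base a typed pointer at a byte offset inside an untyped workspace.
    template <typename T>
    const T* at_byte_offset(const void* base, std::size_t byte_offset)
    {
        return reinterpret_cast<const T*>(static_cast<const char*>(base) + byte_offset);
    }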
@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                    arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;

                const auto clear_workspace = [&]() {
-                    hip_check_error(hipMemsetAsync(
-                        gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
+                    hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
+                                                   0,
+                                                   arg.GetWorkspaceETensorSizeBytes(),
+                                                   stream_config.stream_id_));
                };

                const auto Run = [&](const auto& kernel) {
...
...
@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
+            float avg_time = 0.f;
            auto launch_elementwise_kernel = [&]() {
                const AccDataType* p_c_grid =
                    type_convert<const AccDataType*>(arg.p_workspace_);
                const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
...
...
@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                std::array<index_t, I1> in_out_batch_strides = {
                    static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};

-                const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
+                const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
                                                               ck::Tuple<CElementwiseGridDesc_M_N>,
                                                               ck::Tuple<CElementwiseGridDesc_M_N>,
                                                               ck::Tuple<const AccDataType*>,
...
...
@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                    in_out_batch_strides);
            };

-            float avg_time = RunGemmV3(arg, stream_config);
+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                const index_t grid_size_a =
+                    arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
+                        arg.a_in_transpose_desc_);
+                const index_t grid_size_b =
+                    arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
+                        arg.b_in_transpose_desc_);
+
+                ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
+                                          arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
+                BDataType* p_b_out_grid =
+                    type_convert<BDataType*>(arg.p_workspace_) +
+                    (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
+                        sizeof(BDataType);
+
+                auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
+                                                                ck::Tuple<InputTransposeDescType>,
+                                                                ck::Tuple<InputTransposeDescType>,
+                                                                ck::Tuple<OutputTransposeDescType>,
+                                                                ck::Tuple<OutputTransposeDescType>,
+                                                                ck::Tuple<const ADataType*>,
+                                                                ck::Tuple<BDataType*>,
+                                                                Block2TileMapElementwise,
+                                                                Block2TileMapElementwise,
+                                                                element_wise::PassThrough>;
+
+                avg_time += launch_and_time_kernel(stream_config,
+                                                   kernel_transpose,
+                                                   dim3(grid_size_a + grid_size_b),
+                                                   dim3(BlockSize),
+                                                   0,
+                                                   make_tuple(arg.a_in_transpose_desc_),
+                                                   make_tuple(arg.b_in_transpose_desc_),
+                                                   make_tuple(arg.a_out_transpose_desc_),
+                                                   make_tuple(arg.b_out_transpose_desc_),
+                                                   make_tuple(arg.p_a_grid_),
+                                                   make_tuple(arg.p_b_grid_),
+                                                   make_tuple(p_a_out_grid),
+                                                   make_tuple(p_b_out_grid),
+                                                   arg.elementwise_block_2_ctile_map_transpose_a_,
+                                                   arg.elementwise_block_2_ctile_map_transpose_b_,
+                                                   element_wise::PassThrough{},
+                                                   grid_size_a);
+            }
+
+            avg_time += RunGemmV3(arg, stream_config);
            avg_time += launch_elementwise_kernel();
            return avg_time;
        }
...
...
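kernel_elementwise_dual folds both transposes into a single launch of grid_size_a + grid_size_b workgroups. The block-partitioning idea, sketched as a HIP kernel (assumed semantics; the real kernel additionally threads descriptors and tile maps through):

    #include <hip/hip_runtime.h>

    __global__ void elementwise_dual_sketch(int grid_size_a /*, A/B tensor views... */)
    {
        if(static_cast<int>(blockIdx.x) < grid_size_a)
        {
            // blocks [0, grid_size_a) copy/transpose tiles of tensor A
        }
        else
        {
            // blocks [grid_size_a, gridDim.x) handle tensor B with a re-based
            // index: const int b_block = blockIdx.x - grid_size_a;
        }
    }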
@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            {
                return false;
            }

-            if constexpr(NDimSpatial == 1)
-            {
-                if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
-                {
-                    return false;
-                }
-            }
-            else if constexpr(NDimSpatial == 2)
+            if constexpr(NDimSpatial == 2)
            {
-                if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                               is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+                if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                               is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
                {
                    return false;
                }
            }
            else if constexpr(NDimSpatial == 3)
            {
-                if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                               is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+                if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                               is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
                {
                    return false;
                }
...
...
@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                return false;
            }

+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
+                {
+                    return false;
+                }
+
+                if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
+                {
+                    return false;
+                }
+
+                const index_t input_spatial_acum = ck::accumulate_n<index_t>(
+                    arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
+                const index_t output_spatial_acum = ck::accumulate_n<index_t>(
+                    arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
+
+                if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+
+                if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+
            return true;
        }
...
...
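The four new checks all enforce one rule: a transfer that moves ScalarPerVector elements per instruction needs its contiguous extent to divide evenly, on both the read side (spatial product, NGCHW source) and the write side (G*C or G*K, channel-last destination). Condensed into a single predicate (a sketch; the example values are illustrative):

    // Divisibility rule behind the vectorized transpose transfers above.
    constexpr bool transpose_vectors_ok(long g, long c, long k, long in_spatial, long out_spatial,
                                        int src_vec, int dst_vec)
    {
        return (g * c) % dst_vec == 0 && (g * k) % dst_vec == 0 &&
               in_spatial % src_vec == 0 && out_spatial % src_vec == 0;
    }

    static_assert(transpose_vectors_ok(2, 64, 64, 56 * 56, 28 * 28, 4, 4));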
@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
-            << NumGroupsToMerge << ">";
+            << NumGroupsToMerge;
+
+        if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                     is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+        {
+            str << ", TransposeTransferSrcScalarPerVector: "
+                << TransposeTransferSrcScalarPerVector << ", "
+                << "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
+        }
+        str << ">";
        // clang-format on

        return str.str();
...
...