gaoqiong / composable_kernel_ROCM · Commits · 171ed358
"vscode:/vscode.git/clone" did not exist on "6434d29db0e98dd428fdc87fcf86d4f72c320690"
Unverified commit 171ed358, authored Sep 04, 2024 by Illia Silin; committed by GitHub, Sep 04, 2024.

Merge pull request #148 from ROCm/merge_from_public

Merge from public

Parents: 829e0eb3, e536d321
Changes: 76 files
Showing 20 changed files with 1448 additions and 1102 deletions.
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp (+182 −177)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp (+277 −0)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp (+288 −0)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp (+98 −80)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp (+0 −770)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp (+0 −19)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp (+83 −31)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp (+13 −3)
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp (+37 −21)
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp (+8 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp (+40 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp (+96 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp (+30 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc (+46 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp (+13 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc (+112 −0)
library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp (+41 −1)
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt (+2 −0)
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp (+41 −0)
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp (+41 −0)
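The fmha-side changes in this merge replace the split-KV kernel's dropout plumbing with paged KV-cache support: a per-batch block table, an optional cache_batch_idx remap, and an append-KV pipeline with optional rotary embedding. For orientation, a minimal sketch of the indirection a block table provides; the field names mirror the kargs introduced in the diff below, but the helper itself is hypothetical and not part of the commit:

#include <cstdint>

// Map a logical key position inside batch i_batch to the physical page that
// holds it. block_table_ptr / batch_stride_block_table / page_block_size
// mirror the PageBlockTableKargs fields added by this commit.
int32_t physical_page(const int32_t* block_table_ptr,
                      int64_t batch_stride_block_table,
                      int32_t i_batch,
                      int32_t key_position,
                      int32_t page_block_size)
{
    const int32_t* block_indices = block_table_ptr + i_batch * batch_stride_block_table;
    return block_indices[key_position / page_block_size];
}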
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp (+182 −177)
@@ -32,8 +32,6 @@ struct FmhaFwdSplitKVKernel
     using KDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
     using VDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
     using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
-    using RandValOutputDataType =
-        ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
     using LSEDataType  = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
     using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
     using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
@@ -46,8 +44,10 @@ struct FmhaFwdSplitKVKernel
     static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
-    static constexpr bool kHasDropout       = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
+    static constexpr bool kIsPagedKV        = FmhaPipeline::Problem::kIsPagedKV;
+    static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
+                  "paged-kvcache only supported by batch mode kernels");
     using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -85,8 +85,8 @@ struct FmhaFwdSplitKVKernel
         "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) +
         "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
         (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
         _SS_(FmhaPipeline::name) + "_" +
         "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
         (pn.empty() ? "" : "_" + pn) +
         (BiasEnum == BlockAttentionBiasEnum::NO_BIAS
              ? _SS_("")
              : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
         (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
-        (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
+        (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "");
     #undef _SS_
     #undef _TS_
     // clang-format on
@@ -110,7 +110,6 @@ struct FmhaFwdSplitKVKernel
         void* o_acc_ptr;
         ck_tile::index_t batch;
-        ck_tile::index_t max_seqlen_q;
         ck_tile::index_t seqlen_q;
         ck_tile::index_t seqlen_k;
@@ -136,6 +135,7 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t nhead_stride_lse_acc;
         ck_tile::index_t nhead_stride_o_acc;
+        ck_tile::index_t batch_stride_lse_acc;
         ck_tile::index_t batch_stride_o_acc;
         ck_tile::index_t split_stride_lse_acc;
@@ -173,32 +173,16 @@ struct FmhaFwdSplitKVKernel
         float scale_p;
     };
-    struct CommonDropoutKargs
+    struct PageBlockTableKargs
     {
-        void init_dropout(const float p_drop,
-                          const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
-        {
-            float p_undrop = 1.0 - p_drop;
-            p_undrop_in_uint8_t =
-                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
-            rp_undrop   = 1.0 / p_undrop;
-            drop_seed   = std::get<0>(drop_seed_offset);
-            drop_offset = std::get<1>(drop_seed_offset);
-        }
-        float rp_undrop             = 1;
-        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        bool is_store_randval       = false;
-        uint64_t drop_seed          = 1;
-        uint64_t drop_offset        = 0;
-        void* rand_val_ptr          = nullptr;
-        ck_tile::index_t stride_randval       = 0;
-        ck_tile::index_t nhead_stride_randval = 0;
+        const int32_t* block_table_ptr;
+        ck_tile::index_t batch_stride_block_table;
+        ck_tile::index_t page_block_size;
     };
-    struct BatchModeDropoutKargs : CommonDropoutKargs
+    struct CacheBatchIdxKargs
     {
-        ck_tile::index_t batch_stride_randval = 0;
+        const int32_t* cache_batch_idx;
     };
     struct BatchModeKargs
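BatchModeKargs (continued in the next hunk) now inherits either PageBlockTableKargs or CacheBatchIdxKargs, selected at compile time by kIsPagedKV. A sketch of the non-paged remap, assuming the same semantics as the i_cache_batch lambda that appears later in this file (hypothetical standalone helper, not code from the diff):

#include <cstdint>

// With a contiguous KV cache, an optional cache_batch_idx array redirects each
// batch to its slot in the cache; a null pointer means the identity mapping.
int32_t effective_cache_batch(const int32_t* cache_batch_idx, int32_t i_batch)
{
    return cache_batch_idx != nullptr ? cache_batch_idx[i_batch] : i_batch;
}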
@@ -210,12 +194,13 @@ struct FmhaFwdSplitKVKernel
                                          EmptyKargs<0>>>,
           std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
           std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-          std::conditional_t<kHasDropout, BatchModeDropoutKargs, EmptyKargs<3>>
+          std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
     {
+        const int32_t* seqlen_k_ptr;
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
+        ck_tile::index_t batch_stride_lse_acc;
     };
     struct GroupModeKargs
@@ -226,12 +211,14 @@ struct FmhaFwdSplitKVKernel
                                          AlibiKargs,
                                          EmptyKargs<0>>>,
           std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
-          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-          std::conditional_t<kHasDropout, CommonDropoutKargs, EmptyKargs<3>>
+          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
         const int32_t* seqlen_k_ptr;
+        ck_tile::index_t batch_stride_k;
+        ck_tile::index_t batch_stride_v;
     };
     using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -242,48 +229,45 @@ struct FmhaFwdSplitKVKernel
                                     const void* k_ptr,
                                     const void* v_ptr,
                                     const void* bias_ptr,
-                                    void* rand_val_ptr,
                                     void* lse_acc_ptr,
                                     void* o_acc_ptr,
                                     ck_tile::index_t batch,
-                                    ck_tile::index_t max_seqlen_q,
                                     ck_tile::index_t seqlen_q,
-                                    ck_tile::index_t seqlen_k,
+                                    ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
+                                    const void* seqlen_k_ptr,  // only used for (paged-) kvcache
                                     ck_tile::index_t hdim_q,
                                     ck_tile::index_t hdim_v,
                                     ck_tile::index_t num_head_q,
                                     ck_tile::index_t nhead_ratio_qk,
                                     ck_tile::index_t num_splits,
+                                    const void* block_table_ptr,
+                                    ck_tile::index_t batch_stride_block_table,
+                                    ck_tile::index_t page_block_size,
+                                    const void* cache_batch_idx,
                                     float scale_s,
                                     float scale_p,
                                     ck_tile::index_t stride_q,
                                     ck_tile::index_t stride_k,
                                     ck_tile::index_t stride_v,
                                     ck_tile::index_t stride_bias,
-                                    ck_tile::index_t stride_randval,
                                     ck_tile::index_t stride_o_acc,
                                     ck_tile::index_t nhead_stride_q,
                                     ck_tile::index_t nhead_stride_k,
                                     ck_tile::index_t nhead_stride_v,
                                     ck_tile::index_t nhead_stride_bias,
-                                    ck_tile::index_t nhead_stride_randval,
                                     ck_tile::index_t nhead_stride_lse_acc,
                                     ck_tile::index_t nhead_stride_o_acc,
                                     ck_tile::index_t batch_stride_q,
                                     ck_tile::index_t batch_stride_k,
                                     ck_tile::index_t batch_stride_v,
                                     ck_tile::index_t batch_stride_bias,
-                                    ck_tile::index_t batch_stride_randval,
                                     ck_tile::index_t batch_stride_lse_acc,
                                     ck_tile::index_t batch_stride_o_acc,
                                     ck_tile::index_t split_stride_lse_acc,
                                     ck_tile::index_t split_stride_o_acc,
                                     ck_tile::index_t window_size_left,
                                     ck_tile::index_t window_size_right,
-                                    ck_tile::index_t mask_type,
-                                    float p_drop,
-                                    bool s_randval,
-                                    const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+                                    ck_tile::index_t mask_type)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -291,7 +275,6 @@ struct FmhaFwdSplitKVKernel
                      lse_acc_ptr,
                      o_acc_ptr,
                      batch,
-                     max_seqlen_q,
                      seqlen_q,
                      seqlen_k,
                      hdim_q,
@@ -313,17 +296,18 @@ struct FmhaFwdSplitKVKernel
                      nhead_stride_v,
                      nhead_stride_lse_acc,
                      nhead_stride_o_acc,
+                     batch_stride_lse_acc,
                      batch_stride_o_acc,
                      split_stride_lse_acc,
                      split_stride_o_acc}, // args for common karg
                     {},                   // placeholder for bias
                     {},                   // placeholder for mask
                     {},                   // placeholder for fp8_static_quant args
-                    {},                   // placeholder for dropout
+                    {},                   // placeholder for paged-block table or cache_batch_idx
+                    reinterpret_cast<const int32_t*>(seqlen_k_ptr),
                     batch_stride_q,
                     batch_stride_k,
-                    batch_stride_v};
+                    batch_stride_v,
+                    batch_stride_lse_acc};
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
@@ -347,14 +331,15 @@ struct FmhaFwdSplitKVKernel
         {
             kargs.scale_p = scale_p;
         }
-        if constexpr(kHasDropout)
+        if constexpr(kIsPagedKV)
+        {
+            kargs.block_table_ptr          = reinterpret_cast<const int32_t*>(block_table_ptr);
+            kargs.batch_stride_block_table = batch_stride_block_table;
+            kargs.page_block_size          = page_block_size;
+        }
+        else
         {
-            kargs.init_dropout(p_drop, drop_seed_offset);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.batch_stride_randval = batch_stride_randval;
-            kargs.is_store_randval     = s_randval;
+            kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
         }
         return kargs;
@@ -366,11 +351,9 @@ struct FmhaFwdSplitKVKernel
                                     const void* k_ptr,
                                     const void* v_ptr,
                                     const void* bias_ptr,
-                                    void* rand_val_ptr,
                                     void* lse_acc_ptr,
                                     void* o_acc_ptr,
                                     ck_tile::index_t batch,
-                                    ck_tile::index_t max_seqlen_q,
                                     const void* seqstart_q_ptr,
                                     const void* seqstart_k_ptr,
                                     const void* seqlen_k_ptr,
@@ -385,24 +368,22 @@ struct FmhaFwdSplitKVKernel
                                     ck_tile::index_t stride_k,
                                     ck_tile::index_t stride_v,
                                     ck_tile::index_t stride_bias,
-                                    ck_tile::index_t stride_randval,
                                     ck_tile::index_t stride_o_acc,
                                     ck_tile::index_t nhead_stride_q,
                                     ck_tile::index_t nhead_stride_k,
                                     ck_tile::index_t nhead_stride_v,
                                     ck_tile::index_t nhead_stride_bias,
-                                    ck_tile::index_t nhead_stride_randval,
                                     ck_tile::index_t nhead_stride_lse_acc,
                                     ck_tile::index_t nhead_stride_o_acc,
+                                    ck_tile::index_t batch_stride_k,
+                                    ck_tile::index_t batch_stride_v,
+                                    ck_tile::index_t batch_stride_lse_acc,
                                     ck_tile::index_t batch_stride_o_acc,
                                     ck_tile::index_t split_stride_lse_acc,
                                     ck_tile::index_t split_stride_o_acc,
                                     ck_tile::index_t window_size_left,
                                     ck_tile::index_t window_size_right,
-                                    ck_tile::index_t mask_type,
-                                    float p_drop,
-                                    bool s_randval,
-                                    const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+                                    ck_tile::index_t mask_type)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -410,9 +391,8 @@ struct FmhaFwdSplitKVKernel
                      lse_acc_ptr,
                      o_acc_ptr,
                      batch,
-                     max_seqlen_q,
                      -1, // seqlen_q will be updated by another pointer
                      -1, // seqlen_k will be updated by another pointer
                      hdim_q,
                      hdim_v,
                      num_head_q,
@@ -432,16 +412,18 @@ struct FmhaFwdSplitKVKernel
                      nhead_stride_v,
                      nhead_stride_lse_acc,
                      nhead_stride_o_acc,
+                     batch_stride_lse_acc,
                      batch_stride_o_acc,
                      split_stride_lse_acc,
                      split_stride_o_acc}, // args for common karg
                     {},                   // placeholder for bias
                     {},                   // placeholder for mask
                     {},                   // placeholder for fp8_static_quant args
-                    {},                   // placeholder for dropout
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
-                    reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
+                    reinterpret_cast<const int32_t*>(seqlen_k_ptr),
+                    batch_stride_k,
+                    batch_stride_v};
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
@@ -464,14 +446,6 @@ struct FmhaFwdSplitKVKernel
         {
             kargs.scale_p = scale_p;
         }
-        if constexpr(kHasDropout)
-        {
-            kargs.init_dropout(p_drop, drop_seed_offset);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.is_store_randval     = s_randval;
-        }
         return kargs;
     }
@@ -508,7 +482,6 @@ struct FmhaFwdSplitKVKernel
         long_index_t batch_offset_k       = 0;
         long_index_t batch_offset_v       = 0;
         long_index_t batch_offset_bias    = 0;
-        long_index_t batch_offset_randval = 0;
         long_index_t batch_offset_lse_acc = 0;
         const long_index_t batch_offset_o_acc =
             static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
@@ -534,14 +507,9 @@ struct FmhaFwdSplitKVKernel
             {
                 batch_offset_bias = query_start * kargs.stride_bias + key_start;
             }
-            if constexpr(kHasDropout)
-            {
-                batch_offset_randval = query_start * kargs.stride_randval;
-            }
             // get real # queries & # keys under group mode
-            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+            kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
             // # of required blocks is different in each groups, terminate unnecessary blocks
             // earlier
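The rewritten group-mode length computation reads adjacent entries of a prefix-sum array; the same arithmetic in isolation (a sketch, not code from the diff):

#include <cstdint>

// seqstart holds cumulative sequence starts, so group i spans
// [seqstart[i], seqstart[i + 1]) and its length is the difference of neighbours.
int32_t group_seqlen(const int32_t* seqstart, int32_t i_batch)
{
    return seqstart[i_batch + 1] - seqstart[i_batch];
}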
@@ -556,24 +524,36 @@ struct FmhaFwdSplitKVKernel
             }
             else
             {
-                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
-                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
+                kargs.seqlen_k =
+                    kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
             }
         }
         else
         {
+            const index_t i_cache_batch = [&, i_batch_ = i_batch] {
+                if constexpr(kIsPagedKV)
+                {
+                    return i_batch_;
+                }
+                else
+                {
+                    return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
+                                                             : i_batch_);
+                }
+            }();
+
             batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
-            batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
-            batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
+            batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
+            batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
             batch_offset_lse_acc =
                 static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
             }
-            if constexpr(kHasDropout)
-            {
-                batch_offset_randval =
-                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
-            }
+            if(kargs.seqlen_k_ptr != nullptr)
+            {
+                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
+            }
         }
@@ -589,6 +569,7 @@ struct FmhaFwdSplitKVKernel
             reinterpret_cast<const VDataType*>(kargs.v_ptr) +
             static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
             batch_offset_v;
         OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
                                   static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
                                   batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
@@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel
                                        sequence<kPadSeqLenQ, kPadHeadDimQ>{});
             }
         }();
-        const auto k_dram = [&]() {
+        const auto make_k_dram = [&](const KDataType* data, index_t height) {
             const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-                k_ptr,
-                make_tuple(kargs.seqlen_k, kargs.hdim_q),
+                data, // will update this pointer if using paged-kvcache
+                make_tuple(height, kargs.hdim_q),
                 make_tuple(kargs.stride_k, 1),
                 number<FmhaPipeline::kAlignmentK>{},
                 number<1>{});
@@ -628,13 +610,24 @@ struct FmhaFwdSplitKVKernel
                 k_dram_naive,
                 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                 sequence<kPadSeqLenK, kPadHeadDimQ>{});
+        };
+        const auto k_dram = [&]() {
+            if constexpr(kIsPagedKV)
+            {
+                return make_k_dram(nullptr, kargs.page_block_size);
+            }
+            else
+            {
+                return make_k_dram(k_ptr, kargs.seqlen_k);
+            }
         }();
-        const auto v_dram = [&]() {
+        const auto make_v_dram = [&](const VDataType* data, index_t length) {
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
                 const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-                    v_ptr,
-                    make_tuple(kargs.seqlen_k, kargs.hdim_v),
+                    data, // will update this pointer if using paged-kvcache
+                    make_tuple(length, kargs.hdim_v),
                     make_tuple(kargs.stride_v, 1),
                     number<FmhaPipeline::kAlignmentV>{},
                     number<1>{});
@@ -642,7 +635,7 @@ struct FmhaFwdSplitKVKernel
                 const auto v_dram_transposed =
                     transform_tensor_view(v_dram_naive,
                                           make_tuple(make_pass_through_transform(kargs.hdim_v),
-                                                     make_pass_through_transform(kargs.seqlen_k)),
+                                                     make_pass_through_transform(length)),
                                           make_tuple(sequence<1>{}, sequence<0>{}),
                                           make_tuple(sequence<0>{}, sequence<1>{}));
@@ -654,8 +647,8 @@ struct FmhaFwdSplitKVKernel
             else
             {
                 const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-                    v_ptr,
-                    make_tuple(kargs.hdim_v, kargs.seqlen_k),
+                    data, // will update this pointer if using paged-kvcache
+                    make_tuple(kargs.hdim_v, length),
                     make_tuple(kargs.stride_v, 1),
                     number<FmhaPipeline::kAlignmentV>{},
                     number<1>{});
@@ -665,6 +658,76 @@ struct FmhaFwdSplitKVKernel
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                     sequence<kPadHeadDimV, kPadSeqLenK>{});
             }
+        };
+        const auto v_dram = [&]() {
+            if constexpr(kIsPagedKV)
+            {
+                return make_v_dram(nullptr, kargs.page_block_size);
+            }
+            else
+            {
+                return make_v_dram(v_ptr, kargs.seqlen_k);
+            }
+        }();
+
+        auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+            if constexpr(kIsPagedKV)
+            {
+                const auto* block_indices =
+                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
+                    i_batch_ * kargs.batch_stride_block_table;
+                const index_t num_blocks =
+                    integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
+                const long_index_t fixed_offset =
+                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
+                    kargs.nhead_stride_k;
+                return make_page_block_navigator<const KDataType, 0>(
+                    kargs.k_ptr,
+                    kargs.batch_stride_k,
+                    fixed_offset,
+                    block_indices,
+                    num_blocks,
+                    kargs.page_block_size,
+                    k_dram,
+                    make_k_dram(nullptr,
+                                kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
+            }
+            else
+            {
+                return make_page_block_navigator(k_dram);
+            }
+        }();
+        auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+            if constexpr(kIsPagedKV)
+            {
+                const auto* block_indices =
+                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
+                    i_batch_ * kargs.batch_stride_block_table;
+                const index_t num_blocks =
+                    integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
+                const long_index_t fixed_offset =
+                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
+                    kargs.nhead_stride_v;
+                return make_page_block_navigator<const VDataType, 1>(
+                    kargs.v_ptr,
+                    kargs.batch_stride_v,
+                    fixed_offset,
+                    block_indices,
+                    num_blocks,
+                    kargs.page_block_size,
+                    v_dram,
+                    make_v_dram(nullptr,
+                                kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
+            }
+            else
+            {
+                return make_page_block_navigator(v_dram);
+            }
+        }();
         auto q_dram_window = make_tile_window(
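The navigators above cover seqlen_k keys with fixed-size pages: a ceil-divide yields the page count, and the final page keeps only the remainder. The arithmetic in isolation (a sketch assuming positive lengths; integer_divide_ceil is the ck_tile utility the diff calls):

// Page count, and length of the last, possibly partial, page.
int num_blocks(int seqlen_k, int page_block_size)
{
    return (seqlen_k + page_block_size - 1) / page_block_size; // integer_divide_ceil
}

int last_block_len(int seqlen_k, int page_block_size)
{
    return seqlen_k - (num_blocks(seqlen_k, page_block_size) - 1) * page_block_size;
}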
@@ -678,13 +741,11 @@ struct FmhaFwdSplitKVKernel
             }(),
             {i_m0, 0});
-        auto k_dram_window = make_tile_window(
-            k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
-        auto v_dram_window = make_tile_window(
-            v_dram, make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), {i_n1, 0});
+        auto k_dram_window_lengths =
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
+        auto v_dram_window_lengths =
+            make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});
         /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
         /// following copy capture of the 'i_nhead' if in C++20
         const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
@@ -741,62 +802,6 @@ struct FmhaFwdSplitKVKernel
             return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
         }();
-        // dropout
-        float rp_undrop             = 1;
-        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        uint64_t drop_seed          = 0;
-        uint64_t drop_offset        = 0;
-        bool is_store_randval       = false;
-
-        if constexpr(kHasDropout)
-        {
-            rp_undrop           = kargs.rp_undrop;
-            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
-            drop_seed           = kargs.drop_seed;
-            drop_offset         = kargs.drop_offset;
-            is_store_randval    = kargs.is_store_randval;
-        }
-        BlockDropout dropout(i_batch,
-                             i_nhead,
-                             kargs.num_head_q,
-                             drop_seed,
-                             drop_offset,
-                             rp_undrop,
-                             p_undrop_in_uint8_t,
-                             is_store_randval);
-
-        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
-            constexpr auto randval_dram_window_lengths =
-                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-            if constexpr(kHasDropout)
-            {
-                RandValOutputDataType* rand_val_ptr =
-                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
-                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
-                    batch_offset_randval;
-                const auto randval_dram = [&]() {
-                    const auto randval_dram_naive =
-                        make_naive_tensor_view<address_space_enum::global>(
-                            rand_val_ptr,
-                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
-                            make_tuple(kargs.stride_randval, 1),
-                            number<1>{},
-                            number<1>{});
-                    return pad_tensor_view(randval_dram_naive,
-                                           randval_dram_window_lengths,
-                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
-                }();
-                return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
-            }
-            else
-            {
-                return make_null_tile_window(randval_dram_window_lengths);
-            }
-        }();
-
         FmhaMask mask = [&]() {
             if constexpr(kHasMask)
                 return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
@@ -823,16 +828,16 @@ struct FmhaFwdSplitKVKernel
 #endif
                 if constexpr(kHasMask)
                 {
-                    return make_alibi_from_lr_mask<SaccDataType, true>(slope,
+                    return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
                                                                        kargs.window_size_left,
                                                                        kargs.window_size_right,
                                                                        kargs.seqlen_q,
                                                                        kargs.seqlen_k,
                                                                        kargs.mask_type);
                 }
                 else
                 {
-                    return Alibi<SaccDataType, true>{
+                    return Alibi<SaccDataType, true, 32>{
                         slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                 }
             }
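For reference, the bias the Alibi functor contributes is linear in query/key distance, scaled by a per-head slope; a scalar sketch of the usual formulation (not from the diff, and the added template argument 32 is copied verbatim from the commit rather than explained here):

// ALiBi adds slope * (k_pos - q_pos) to the attention score before softmax,
// so in a causal layout keys farther in the past get a more negative bias.
float alibi_bias(float slope, int q_pos, int k_pos)
{
    return slope * static_cast<float>(k_pos - q_pos);
}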
@@ -847,13 +852,14 @@ struct FmhaFwdSplitKVKernel
             {
                 return FmhaPipeline{}(q_dram_window,
                                       identity{}, // q_element_func
-                                      k_dram_window,
+                                      k_dram_window_lengths,
+                                      k_page_block_navigator,
                                       identity{}, // k_element_func
-                                      v_dram_window,
+                                      v_dram_window_lengths,
+                                      v_page_block_navigator,
                                       identity{}, // v_element_func
                                       bias_dram_window,
                                       identity{}, // bias_element_func
-                                      randval_dram_window,
                                       lse_acc_dram_window,
                                       identity{}, // lse_element_func
                                       identity{}, // s_acc_element_func
@@ -864,24 +870,23 @@ struct FmhaFwdSplitKVKernel
                                       mask,
                                       position_encoding,
                                       kargs.scale_s,
-                                      smem_ptr,
-                                      dropout);
+                                      smem_ptr);
             }
             else
             {
                 return FmhaPipeline{}(q_dram_window,
-                                      k_dram_window,
-                                      v_dram_window,
+                                      k_dram_window_lengths,
+                                      k_page_block_navigator,
+                                      v_dram_window_lengths,
+                                      v_page_block_navigator,
                                       bias_dram_window,
-                                      randval_dram_window,
                                       lse_acc_dram_window,
                                       kargs.num_splits,
                                       i_split_,
                                       mask,
                                       position_encoding,
                                       kargs.scale_s,
-                                      smem_ptr,
-                                      dropout);
+                                      smem_ptr);
             }
         }();
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp (new file, mode 100644, +277 −0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"

namespace ck_tile {

template <typename Problem_, typename Policy_ = BlockFmhaFwdAppendKVPipelineDefaultPolicy>
struct BlockFmhaFwdAppendKVPipeline
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    using QDataType = typename Problem::QDataType;
    using KDataType = typename Problem::KDataType;
    using VDataType = typename Problem::VDataType;
    using VLayout   = typename Problem::VLayout;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kM0 = Problem::kM0;
    static constexpr index_t kN0 = Problem::kN0;
    static constexpr index_t kK0 = Problem::kK0;
    static constexpr index_t kN1 = Problem::kN1;

    static constexpr auto RotaryEnum = Problem::RotaryEnum;
    static constexpr bool kIsPagedKV = Problem::kIsPagedKV;

    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;

    // last dimension vector length used to create tensor view (and decide buffer_load vector
    // length) together with tensor distribution; the tensor dist should be able to overwrite this
    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV = []() {
        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
        else
            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
    }();

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            if constexpr(kK0 <= 32)
            {
                return 2;
            }
            else if constexpr(kK0 <= 64)
            {
                return 3;
            }
            else if constexpr(kK0 <= 128)
            {
                return 2;
            }
            else if constexpr(kK0 <= 256)
            {
                return 1;
            }
        }
    }();

    template <typename QDramBlockWindow,
              typename KDramBlockWindow,
              typename KPageBlockNavigator,
              typename KnewDramBlockWindow,
              typename VDramBlockWindow,
              typename VPageBlockNavigator,
              typename VnewDramBlockWindow,
              typename QElementFunction,
              typename KnewElementFunction,
              typename VnewElementFunction,
              typename QRotaryCosDramBlockWindow,
              typename QRotarySinDramBlockWindow,
              typename KnewRotaryCosDramBlockWindow,
              typename KnewRotarySinDramBlockWindow>
    CK_TILE_HOST_DEVICE auto
    operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
               const QElementFunction& q_element_func,
               KDramBlockWindow& k_dram_block_window, // N0*K0 tile
               index_t i_page_block_k,
               const KPageBlockNavigator& k_page_block_navigator,
               const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
               const KnewElementFunction& knew_element_func,
               VDramBlockWindow& v_dram_block_window, // N1*N0 tile
               index_t i_page_block_v,
               const VPageBlockNavigator& v_page_block_navigator,
               const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
               const VnewElementFunction& vnew_element_func,
               const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
               const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
               const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
               const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
               index_t rotary_dim,
               bool skip_rotate_q,
               bool skip_rotate_append_kv) const
    {
        if(!skip_rotate_append_kv)
        {
            // append Knew to K
            auto knew_window = make_tile_window(
                knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());

            auto knew_tile = [&]() {
                auto knew = load_tile(knew_window);
                return tile_elementwise_in(knew_element_func, knew);
            }();

            // optionally apply rotary embedding to Knew
            if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
            {
                auto rotary_cos_window = make_tile_window(
                    knew_rotary_cos_dram_block_window,
                    Policy::template MakeRotaryCosSinTileDistribution<Problem,
                                                                      /*IsRotaryCosSinForQ=*/false>());
                auto rotary_sin_window = make_tile_window(
                    knew_rotary_sin_dram_block_window,
                    Policy::template MakeRotaryCosSinTileDistribution<Problem,
                                                                      /*IsRotaryCosSinForQ=*/false>());

                // We assume that each thread owns contiguous elements on the head dimension, and
                // we use the distribution to enable/disable threads in order to override partial
                // knew_tile content
                auto [thread_start, thread_end] =
                    Policy::template GetKnewThreadRangeAlongK<Problem>();
                ignore = thread_start;

                BlockRotaryEmbedding<RotaryEnum>::apply(
                    knew_tile, knew_window, rotary_cos_window, rotary_sin_window, rotary_dim,
                    thread_end);
            }

            store_tile(k_dram_block_window, knew_tile);
            // write tile to another block if necessary
            if constexpr(kIsPagedKV)
            {
                if(k_page_block_navigator.is_cross_block(i_page_block_k, k_dram_block_window))
                {
                    k_page_block_navigator.move_to_block(
                        i_page_block_k, k_dram_block_window, i_page_block_k + 1);
                    store_tile(k_dram_block_window, knew_tile);
                }
            }

            // append Vnew to V
            auto vnew_window = make_tile_window(
                vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
            auto vnew_tile = [&]() {
                auto vnew = load_tile(vnew_window);
                return tile_elementwise_in(vnew_element_func, vnew);
            }();
            store_tile(v_dram_block_window, vnew_tile);
            // write tile to another block if necessary
            if constexpr(kIsPagedKV)
            {
                if(v_page_block_navigator.is_cross_block(i_page_block_v, v_dram_block_window))
                {
                    v_page_block_navigator.move_to_block(
                        i_page_block_v, v_dram_block_window, i_page_block_v + 1);
                    store_tile(v_dram_block_window, vnew_tile);
                }
            }
        }

        if(!skip_rotate_q)
        {
            // optionally apply rotary embedding to Q
            if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
            {
                auto q_window = make_tile_window(
                    q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
                auto q_tile = [&]() {
                    auto q = load_tile(q_window);
                    return tile_elementwise_in(q_element_func, q);
                }();

                auto rotary_cos_window = make_tile_window(
                    q_rotary_cos_dram_block_window,
                    Policy::template MakeRotaryCosSinTileDistribution<Problem,
                                                                      /*IsRotaryCosSinForQ=*/true>());
                auto rotary_sin_window = make_tile_window(
                    q_rotary_sin_dram_block_window,
                    Policy::template MakeRotaryCosSinTileDistribution<Problem,
                                                                      /*IsRotaryCosSinForQ=*/true>());

                // We assume that each thread owns contiguous elements on the head dimension, and
                // we use the distribution to enable/disable threads in order to override partial
                // q_tile content
                auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
                ignore = thread_start;

                BlockRotaryEmbedding<RotaryEnum>::apply(
                    q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end);

                store_tile(q_dram_block_window, q_tile);
            }
        }
    }

    template <typename QDramBlockWindow,
              typename KDramBlockWindow,
              typename KPageBlockNavigator,
              typename KnewDramBlockWindow,
              typename VDramBlockWindow,
              typename VPageBlockNavigator,
              typename VnewDramBlockWindow,
              typename QRotaryCosDramBlockWindow,
              typename QRotarySinDramBlockWindow,
              typename KnewRotaryCosDramBlockWindow,
              typename KnewRotarySinDramBlockWindow>
    CK_TILE_HOST_DEVICE auto
    operator()(QDramBlockWindow& q_dram_block_window,
               KDramBlockWindow& k_dram_block_window,
               index_t i_page_block_k,
               const KPageBlockNavigator& k_page_block_navigator,
               const KnewDramBlockWindow& knew_dram_block_window,
               VDramBlockWindow& v_dram_block_window,
               index_t i_page_block_v,
               const VPageBlockNavigator& v_page_block_navigator,
               const VnewDramBlockWindow& vnew_dram_block_window,
               const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
               const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
               const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
               const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
               index_t rotary_dim,
               bool skip_rotate_q,
               bool skip_rotate_append_kv) const
    {
        return operator()(q_dram_block_window,
                          identity{},
                          k_dram_block_window,
                          i_page_block_k,
                          k_page_block_navigator,
                          knew_dram_block_window,
                          identity{},
                          v_dram_block_window,
                          i_page_block_v,
                          v_page_block_navigator,
                          vnew_dram_block_window,
                          identity{},
                          q_rotary_cos_dram_block_window,
                          q_rotary_sin_dram_block_window,
                          knew_rotary_cos_dram_block_window,
                          knew_rotary_sin_dram_block_window,
                          rotary_dim,
                          skip_rotate_q,
                          skip_rotate_append_kv);
    }
};

} // namespace ck_tile
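The pipeline above delegates the actual rotation to BlockRotaryEmbedding<RotaryEnum>::apply. For reference, a minimal scalar sketch of the interleaved rotary-embedding variant (a hypothetical standalone function; the in-kernel version is vectorized and distributed across threads):

#include <vector>

// Rotate consecutive (even, odd) element pairs of one head vector by the
// per-position cos/sin tables; only the first rotary_dim elements are touched.
void rotary_interleaved(std::vector<float>& x,
                        const std::vector<float>& cos_v,
                        const std::vector<float>& sin_v,
                        int rotary_dim)
{
    for(int i = 0; i + 1 < rotary_dim; i += 2)
    {
        const float x0 = x[i], x1 = x[i + 1];
        const float c = cos_v[i / 2], s = sin_v[i / 2];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}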
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp (new file, mode 100644, +288 −0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// This pipeline is qkv all located in LDS
struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
    {
        using QDataType = remove_cvref_t<typename Problem::QDataType>;
        return 16 / sizeof(QDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
    {
        using KDataType = remove_cvref_t<typename Problem::KDataType>;
        return 16 / sizeof(KDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
    {
        using VLayout   = remove_cvref_t<typename Problem::VLayout>;
        using VDataType = remove_cvref_t<typename Problem::VDataType>;
        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            constexpr index_t kBlockSize   = Problem::kBlockSize;
            constexpr index_t kNPerBlock   = Problem::kN0;
            constexpr index_t kKPerBlock   = Problem::kN1;
            constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
            // TODO: not correct!
            if constexpr(total_pixels > 4)
                return 4;
            else
                return 2;
        }
        else
        {
            return 16 / sizeof(VDataType);
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQNumElemsPerRead()
    {
        using DataType = typename Problem::QDataType;
        if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
        {
            /// NOTICE: we might need to lower this to support smaller rotary_dim
            return 16 / sizeof(DataType);
        }
        else
        {
            return 16 / sizeof(DataType);
        }
    }

    template <typename Problem>
    CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
    {
        static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);

        if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
        {
            constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
            static_assert(Problem::kK0 % KPerThread == 0);
            constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
            index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
            return make_tuple(start_pos, start_pos + KPerThread);
        }
        else
        {
            constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
            static_assert(Problem::kK0 % KPerThread == 0);
            constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
            index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
            return make_tuple(start_pos, start_pos + KPerThread);
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::kM0;
        constexpr index_t kKPerBlock = Problem::kK0;

        constexpr index_t KPerThread      = GetQNumElemsPerRead<Problem>();
        constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
        constexpr index_t MThreadPerWarp  = get_warp_size() / KThreadPerBlock;
        constexpr index_t NumWarps        = kBlockSize / get_warp_size();
        constexpr index_t MPerThread      = kMPerBlock / (NumWarps * MThreadPerWarp);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
                                             sequence<KThreadPerBlock, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetKnewNumElemsPerRead()
    {
        using DataType = typename Problem::KDataType;
        if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
        {
            /// NOTICE: we might need to lower this to support smaller rotary_dim
            return 16 / sizeof(DataType);
        }
        else
        {
            return 16 / sizeof(DataType);
        }
    }

    template <typename Problem>
    CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
    {
        static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);

        if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
        {
            constexpr index_t KPerThread      = GetKnewNumElemsPerRead<Problem>();
            constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
            index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
            return make_tuple(start_pos, start_pos + KPerThread);
        }
        else
        {
            constexpr index_t KPerThread      = GetKnewNumElemsPerRead<Problem>();
            constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
            index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
            return make_tuple(start_pos, start_pos + KPerThread);
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::kN0;
        constexpr index_t kKPerBlock = Problem::kK0;

        constexpr index_t KPerThread      = GetKnewNumElemsPerRead<Problem>();
        constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
        constexpr index_t NThreadPerWarp  = get_warp_size() / KThreadPerBlock;
        constexpr index_t NumWarps        = kBlockSize / get_warp_size();
        constexpr index_t NPerThread      = kNPerBlock / (NumWarps * NThreadPerWarp);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
                                             sequence<KThreadPerBlock, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
    {
        // TODO: this is for 3d layout
        using VDataType = remove_cvref_t<typename Problem::VDataType>;
        return 16 / sizeof(VDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeVnewDramTileDistribution()
    {
        using VLayout   = remove_cvref_t<typename Problem::VLayout>;
        using VDataType = remove_cvref_t<typename Problem::VDataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::kN1;
        constexpr index_t kKPerBlock = Problem::kN0;

        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            constexpr index_t NPerThread      = 16 / sizeof(VDataType);
            constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
            constexpr index_t KThreadPerWarp  = get_warp_size() / NThreadPerBlock;
            constexpr index_t NumWarps        = kBlockSize / get_warp_size();
            constexpr index_t KPerThread      = kKPerBlock / (NumWarps * KThreadPerWarp);

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<NThreadPerBlock, NPerThread>,
                                                 sequence<KPerThread, NumWarps, KThreadPerWarp>>,
                                           tuple<sequence<2>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 0>>{});
        }
        else
        {
            constexpr index_t KPerThread      = 16 / sizeof(VDataType);
            constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
            constexpr index_t NThreadPerWarp  = get_warp_size() / KThreadPerBlock;
            constexpr index_t NumWarps        = kBlockSize / get_warp_size();
            constexpr index_t NPerThread      = kNPerBlock / (NumWarps * NThreadPerWarp);

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
                                                 sequence<KThreadPerBlock, KPerThread>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
    }

    template <typename Problem, bool IsRotaryCosSinForQ>
    CK_TILE_HOST_DEVICE static constexpr auto GetRotaryCosSinTileSize()
    {
        constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
        if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
        {
            return make_tuple(number<height>{}, number<Problem::kK0>{});
        }
        else
        {
            return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
        }
    }

    template <typename Problem, bool IsRotaryCosSinForQ>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
    {
        using DataType = std::conditional_t<IsRotaryCosSinForQ,
                                            typename Problem::QDataType,
                                            typename Problem::KDataType>;

        constexpr auto TileSize = GetRotaryCosSinTileSize<Problem, IsRotaryCosSinForQ>();

        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = TileSize[number<0>{}];
        constexpr index_t kKPerBlock = TileSize[number<1>{}];

        constexpr index_t KPerThread = []() {
            if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
            {
                /// NOTICE: we might need to lower this to support smaller rotary_dim
                return 16 / sizeof(DataType);
            }
            else
            {
                return 8 / sizeof(DataType);
            }
        }();
        constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
        constexpr index_t NThreadPerWarp  = get_warp_size() / KThreadPerBlock;
        constexpr index_t NumWarps        = kBlockSize / get_warp_size();
        constexpr index_t NPerThread      = kNPerBlock / (NumWarps * NThreadPerWarp);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
                                             sequence<KThreadPerBlock, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
};

} // namespace ck_tile
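All of the GetAlignment*/NumElemsPerRead helpers above derive a per-thread vector width from a 16-byte load granularity; a worked instance, assuming a 2-byte element type such as fp16:

#include <cstdint>

// 16-byte buffer_load granularity: a 2-byte element type gives
// 16 / 2 = 8 contiguous elements per thread read.
static_assert(16 / sizeof(uint16_t) == 8, "8 two-byte elements per 16-byte read");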
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp (+98 −80)
@@ -6,7 +6,6 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
-#include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"

 namespace ck_tile {

@@ -15,19 +14,18 @@ namespace ck_tile {
 template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
 struct BlockFmhaFwdSplitKVPipelineQRKSVS
 {
     using Problem             = remove_cvref_t<Problem_>;
     using Policy              = remove_cvref_t<Policy_>;
     using QDataType           = remove_cvref_t<typename Problem::QDataType>;
     using KDataType           = remove_cvref_t<typename Problem::KDataType>;
     using VDataType           = remove_cvref_t<typename Problem::VDataType>;
     using SaccDataType        = remove_cvref_t<typename Problem::SaccDataType>;
     using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
     using BiasDataType        = remove_cvref_t<typename Problem::BiasDataType>;
-    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
     using LSEDataType         = remove_cvref_t<typename Problem::LSEDataType>;
     using PDataType           = remove_cvref_t<typename Problem::PDataType>;
     using OaccDataType        = remove_cvref_t<typename Problem::OaccDataType>;
     using FmhaMask            = remove_cvref_t<typename Problem::FmhaMask>;
     using BlockFmhaShape      = remove_cvref_t<typename Problem::BlockFmhaShape>;
     using VLayout             = remove_cvref_t<typename BlockFmhaShape::VLayout>;

@@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = true; // always store LSE (acc)
-    static constexpr bool kHasDropout  = false; // ignore this flag
+    static constexpr bool kIsPagedKV   = Problem::kIsPagedKV;
     static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;

     // last dimension vector length used to create tensor view (and decide buffer_load vector length)

@@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     }

     template <typename QDramBlockWindowTmp,
-              typename KDramBlockWindowTmp,
-              typename VDramBlockWindowTmp,
+              typename KDramBlockWindowLengths,
+              typename KPageBlockNavigator,
+              typename VDramBlockWindowLengths,
+              typename VPageBlockNavigator,
               typename BiasDramBlockWindowTmp,
-              typename RandValDramBlockWindowTmp,
               typename LSEaccDramBlockWindowTmp,
               typename QElementFunction,
               typename KElementFunction,

@@ -123,13 +122,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                const QElementFunction& q_element_func,
-               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
+               const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
+               const KPageBlockNavigator& k_page_block_navigator,
                const KElementFunction& k_element_func,
-               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+               const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
+               const VPageBlockNavigator& v_page_block_navigator,
                const VElementFunction& v_element_func,
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
                const BiasElementFunction& bias_element_func,
-               RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
                const LSEaccElementFunction& lse_acc_element_func,
                const SAccElementFunction& s_acc_element_func,

@@ -140,20 +140,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                FmhaMask mask,
                PositionEncoding position_encoding,
                float scale_s,
-               void* smem_ptr,
-               BlockDropout& dropout) const
+               void* smem_ptr) const
     {
         static_assert(
             std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
-                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
-                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
+                std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
+                std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
             "wrong!");
         static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
-                          kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                          kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
+                          kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
+                          kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
+                          kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
+                          kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
                           kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                           kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                       "wrong!");

@@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
             q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
-        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

         // check early exit if masked and no work to do.
         if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
         {
-            if(num_total_loop <= 0)
+            const index_t original_num_total_loop =
+                integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
+            if(original_num_total_loop <= 0)
             {
                 if constexpr(kStoreLSE)
                 {

@@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             }
         }

-        auto k_dram_block_window =
-            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
-                             k_dram_block_window_tmp.get_window_lengths(),
-                             {seqlen_k_start, 0});
+        // make sure the first tile is completely located in page-block
+        const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
+            if constexpr(kIsPagedKV)
+            {
+                return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
+            }
+            else
+            {
+                return seqlen_k_start_;
+            }
+        }();
+        const index_t num_total_loop =
+            integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
+
+        auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
+            k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
         auto bias_dram_window  = make_tile_window(
             bias_dram_block_window_tmp.get_bottom_tensor_view(),
             bias_dram_block_window_tmp.get_window_lengths(),
-            {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+            {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
             Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());

-        auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
-            randval_dram_block_window_tmp, seqlen_k_start);
-
-        auto v_dram_window =
-            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
-                             v_dram_block_window_tmp.get_window_lengths(),
-                             {0, seqlen_k_start}, // TODO: hdim split?
-                             Policy::template MakeVDramTileDistribution<Problem>());
+        auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
+            v_dram_block_window_lengths,
+            {0, adjusted_seqlen_k_start}, // TODO: hdim split?
+            Policy::template MakeVDramTileDistribution<Problem>());

         auto q_tile = tile_elementwise_in(q_element_func, q);

@@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         {
             // STAGE 1, QK gemm
             auto k_dram_window = make_tile_window(
-                k_dram_block_window.get_bottom_tensor_view(),
-                k_dram_block_window.get_window_lengths(),
-                k_dram_block_window.get_window_origin(),
+                k_dram_block_window,
                 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                         // load

             auto k_block_tile = load_tile(k_dram_window);
             {
+                // moving k_dram_window is an in-page-block operation, so there is
+                // no need to invoke k_page_block_navigator.move_tile_window() here.
                 move_tile_window(k_dram_window, {0, kK0});
                 clear_tile(s_acc); // initialize C
                 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));

@@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             }
             else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
             {
-                const auto k_origin    = k_dram_block_window.get_window_origin();
+                const auto k_origin    = k_page_block_navigator.to_global_window_origin(
+                    i_page_block_k, k_dram_block_window.get_window_origin());
                 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
                 s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
                 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {

@@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             }
             move_tile_window(bias_dram_window, {0, kN0});

-            /// TODO: only check in last iteration without increasing code size
+            /// TODO: only check in first/last iteration without increasing code size
             if constexpr(kHasUnevenSplits)
             {
-                const auto k_origin = k_dram_block_window.get_window_origin();
+                const auto k_origin = k_page_block_navigator.to_global_window_origin(
+                    i_page_block_k, k_dram_block_window.get_window_origin());
                 set_tile_if(s_acc,
                             -numeric<SMPLComputeDataType>::infinity(),
-                            [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
+                            [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
+                                auto tile_idx) {
                                 const auto col =
                                     k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
-                                return seqlen_k_end_ <= col;
+                                if constexpr(kIsPagedKV)
+                                {
+                                    return col < seqlen_k_start_ || seqlen_k_end_ <= col;
+                                }
+                                else
+                                {
+                                    return seqlen_k_end_ <= col;
+                                }
                             });
             }

             if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
             {
-                const auto k_origin = k_dram_block_window.get_window_origin();
+                const auto k_origin = k_page_block_navigator.to_global_window_origin(
+                    i_page_block_k, k_dram_block_window.get_window_origin());
                 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                            k_origin.at(number<0>{}),
                                                            number<kM0>{},

@@ -501,12 +519,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                 });
             });

-            if constexpr(kHasDropout)
-            {
-                dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
-                    smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
-            }
-
             block_sync_lds();
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {

@@ -522,7 +534,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                 store_tile(v_lds_window,
                            tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
             }
-            move_tile_window(v_dram_window, {0, kK1});
+            i_page_block_v = v_page_block_navigator.move_tile_window(
+                i_page_block_v, v_dram_window, {0, kK1});

             const auto p =
                 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));

@@ -530,8 +543,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             // STAGE 3, KV gemm
             if constexpr(k1_loops > 1)
             {
-                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
-                    const auto v = load_tile(v_dram_window); // load next v
+                static_for<0, k1_loops - 1, 1>{}([&,
+                                                  &i_page_block_v_ = i_page_block_v,
+                                                  &v_dram_window_  = v_dram_window](auto i_k1) {
+                    const auto v = load_tile(v_dram_window_); // load next v
                     block_sync_lds();
                     gemm_1(o_acc,
                            get_slice_tile(

@@ -552,11 +567,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                         store_tile(v_lds_window,
                                    tile_elementwise_in(v_element_func, v)); // store next v
                     }
-                    move_tile_window(v_dram_window, {0, kK1});
+                    i_page_block_v_ = v_page_block_navigator.move_tile_window(
+                        i_page_block_v_, v_dram_window_, {0, kK1});
                 });
             }
             // move K tile windows
-            move_tile_window(k_dram_block_window, {kN0, 0});
+            i_page_block_k = k_page_block_navigator.move_tile_window(
+                i_page_block_k, k_dram_block_window, {kN0, 0});
             // tail
             {
                 block_sync_lds();

@@ -618,36 +635,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     }

     template <typename QDramBlockWindowTmp,
-              typename KDramBlockWindowTmp,
-              typename VDramBlockWindowTmp,
+              typename KDramBlockWindowLengths,
+              typename KPageBlockNavigator,
+              typename VDramBlockWindowLengths,
+              typename VPageBlockNavigator,
               typename BiasDramBlockWindowTmp,
-              typename RandValDramBlockWindowTmp,
               typename LSEaccDramBlockWindowTmp,
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
-               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
-               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+               const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
+               const KPageBlockNavigator& k_page_block_navigator,
+               const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
+               const VPageBlockNavigator& v_page_block_navigator,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
-               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
               LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
               index_t num_splits,
               index_t i_split,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
-               void* smem_ptr,
-               BlockDropout& dropout) const
+               void* smem_ptr) const
    {
        return operator()(q_dram_block_window_tmp,
                          identity{},
-                          k_dram_block_window_tmp,
+                          k_dram_block_window_lengths,
+                          k_page_block_navigator,
                          identity{},
-                          v_dram_block_window_tmp,
+                          v_dram_block_window_lengths,
+                          v_page_block_navigator,
                          identity{},
                          bias_dram_block_window_tmp,
                          identity{},
-                          randval_dram_block_window_tmp,
                          lse_acc_dram_block_window_tmp,
                          identity{},
                          identity{},

@@ -658,8 +677,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                          mask,
                          position_encoding,
                          scale_s,
-                          smem_ptr,
-                          dropout);
+                          smem_ptr);
    }
};
...
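The key change in this hunk is the paged-KV handling: the K/V windows must start on a kN0-aligned boundary so the first tile never straddles a page block, and the columns in front of the real split start are then masked to -inf. A standalone sketch of that index arithmetic (plain C++; the sequence-length values are illustrative, not taken from this commit):

#include <cstdio>

int main()
{
    // Illustrative values: a split covering K columns [70, 200) with kN0 = 64.
    const int kN0 = 64, seqlen_k_start = 70, seqlen_k_end = 200;

    // Align the window start down to a tile boundary (integer_divide_floor above).
    const int adjusted_start = kN0 * (seqlen_k_start / kN0); // 64

    // The loop count is taken from the adjusted start (integer_divide_ceil above).
    const int num_total_loop = (seqlen_k_end - adjusted_start + kN0 - 1) / kN0; // 3

    // Columns in [adjusted_start, seqlen_k_start) belong to another split, so they
    // are masked with -inf, mirroring the kIsPagedKV branch of set_tile_if above.
    for(int col = adjusted_start; col < adjusted_start + kN0; ++col)
    {
        if(col < seqlen_k_start)
            std::printf("col %d: masked (before split start)\n", col);
    }
    std::printf("loops=%d\n", num_total_loop);
}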
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
deleted 100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_,
          typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
    using Problem               = remove_cvref_t<Problem_>;
    using Policy                = remove_cvref_t<Policy_>;
    using QDataType             = remove_cvref_t<typename Problem::QDataType>;
    using KDataType             = remove_cvref_t<typename Problem::KDataType>;
    using VDataType             = remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType          = remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType   = remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using BiasDataType          = remove_cvref_t<typename Problem::BiasDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    using LSEDataType           = remove_cvref_t<typename Problem::LSEDataType>;
    using PDataType             = remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType          = remove_cvref_t<typename Problem::OaccDataType>;
    using FmhaMask              = remove_cvref_t<typename Problem::FmhaMask>;
    using BlockFmhaShape        = remove_cvref_t<typename Problem::BlockFmhaShape>;
    using VLayout               = remove_cvref_t<typename BlockFmhaShape::VLayout>;

    static constexpr bool kQLoadOnce = true; // if q_tile loads the whole block length (hdim) at once
    static_assert(kQLoadOnce == Policy::QLoadOnce);

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kM0            = BlockFmhaShape::kM0;
    static constexpr index_t kN0            = BlockFmhaShape::kN0;
    static constexpr index_t kK0            = BlockFmhaShape::kK0;
    static constexpr index_t kN1            = BlockFmhaShape::kN1;
    static constexpr index_t kK1            = BlockFmhaShape::kK1;
    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
    //       only need special care about seq_k padding (oob need set -INF of p instead of zero)
    static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
                  Problem::kPadHeadDimV == true);
    static constexpr bool kPadSeqLenQ  = true;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = true;  // support multiple of vector(like 8x)
    static constexpr bool kPadHeadDimV = true;  // support multiple of vector(like 8x)
    static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kStoreLSE    = true;  // always store LSE (acc)
    static constexpr bool kHasDropout  = false; // ignore this flag
    static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;

    // last dimension vector length used to create tensor view (and decide buffer_load vector length)
    // ... together with tensor distribution. tensor dist should be able to overwrite this
    static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV = []() {
        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            return Policy::template GetAlignmentV<Problem>();
        else
            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
    }();
    static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();

#if CK_TILE_FMHA_FWD_FAST_EXP2
    static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            if constexpr(kK0BlockLength <= 32)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
                             FmhaMask::IsMasking)
                    return 1;
                else
                    return 2;
            }
            else if constexpr(kK0BlockLength <= 64)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 2;
                else
                    return 3;
            }
            else if constexpr(kK0BlockLength <= 128)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                else
                    return 2;
            }
            else if constexpr(kK0BlockLength <= 256)
            {
                return 1;
            }
        }
    }();

    static constexpr const char* name = "qr_async";

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }
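    // NOTE: under CK_TILE_FMHA_FWD_FAST_EXP2 the softmax below is evaluated with exp2
    // rather than exp; the ELEMENTWISE_BIAS branch multiplies the bias by log2e_v
    // explicitly, and scale_s is presumably pre-scaled by log2(e) by the caller in the
    // other paths. R_LOG2E = 1 / log2(e) then undoes that factor when the LSE is
    // stored in the natural-log domain: lse = m * scale_s * R_LOG2E + log(l).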
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename LSEaccDramBlockWindowTmp,
              typename QElementFunction,
              typename KElementFunction,
              typename VElementFunction,
              typename BiasElementFunction,
              typename LSEaccElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
              typename OAccElementFunction,
              typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const QElementFunction& q_element_func,
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const KElementFunction& /*k_element_func*/,
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               const VElementFunction& v_element_func,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               const BiasElementFunction& bias_element_func,
               RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
               const LSEaccElementFunction& lse_acc_element_func,
               const SAccElementFunction& s_acc_element_func,
               const PComputeElementFunction& p_compute_element_func,
               const OAccElementFunction& o_acc_element_func,
               index_t num_splits,
               index_t i_split,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr,
               BlockDropout& dropout) const
    {
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();

        // K tile in LDS
        auto k_lds_ptr   = reinterpret_cast<KDataType*>(smem_ptr);
        auto k_lds_store = generate_tuple(
            [&](auto i_buf) {
                return make_tile_window(
                    make_tensor_view<address_space_enum::lds>(
                        k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
                    Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
                    {0, 0, 0});
            },
            number<Policy::NumPrefetchK>{});

#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
        auto k_lds_load = generate_tuple(
            [&](auto i_buf) {
                return make_tile_window(
                    make_tensor_view<address_space_enum::lds>(
                        k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf)),
                    Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf).get_lengths(),
                    {0, 0});
            },
            number<Policy::NumPrefetchK>{});
#else
        auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
        auto k_lds_load = make_tile_window(
            k_lds_Load_view,
            Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
            {0, 0});
#endif

        // V tile in LDS
        auto v_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<VDataType*>(smem_ptr),
            Policy::template MakeVLdsBlockDescriptor<Problem>());
        auto v_lds_window = make_tile_window(
            v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});

        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

        auto q_dram_window = make_tile_window(
            q_dram_block_window_tmp.get_bottom_tensor_view(),
            q_dram_block_window_tmp.get_window_lengths(),
            q_dram_block_window_tmp.get_window_origin(),
            Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());

        // TODO: we use async Copy for K, which is inline asm
        // a side effect is we have to use inline asm for q as well
        auto q = decltype(load_tile(q_dram_window)){};
        set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
        load_tile_raw(q, q_dram_window);
        __builtin_amdgcn_sched_barrier(0);

        using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
        auto s_acc              = SaccBlockTileType{};

        // reduction function for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // infer Sacc, S, P, M, L, Oacc type
        using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));

        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));

        using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());

        // init Oacc, M, L
        auto o_acc = OaccBlockTileType{};
        auto m     = MLBlockTileType{};
        auto l     = MLBlockTileType{};

        clear_tile(o_acc);
        set_tile(m, -numeric<SMPLComputeDataType>::infinity());
        clear_tile(l);
        __builtin_amdgcn_sched_barrier(0);

        const auto q_origin = q_dram_window.get_window_origin();
        const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
            q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);

        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

        // check early exit if masked and no work to do.
        if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
        {
            if(num_total_loop <= 0)
            {
                if constexpr(kStoreLSE)
                {
                    auto lse_acc =
                        make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
                    set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
                    store_tile(lse_acc_dram_window_tmp,
                               tile_elementwise_in(lse_acc_element_func, lse_acc));
                }
                buffer_load_fence(0);
                // rocm-6.1: if the whole tile is masked out, need to fence(0),
                // otherwise there will be compute errors (maybe a compiler bug?)

                // Note: o_acc is all cleared here, return it
                return o_acc;
            }
            __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
        }

        auto k_dram_block_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_k_start, 0});

        auto k_dram_window = make_tile_window(
            k_dram_block_window.get_bottom_tensor_view(),
            k_dram_block_window.get_window_lengths(),
            k_dram_block_window.get_window_origin(),
            Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load

        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_window  = make_tile_window(
            bias_dram_block_window_tmp.get_bottom_tensor_view(),
            bias_dram_block_window_tmp.get_window_lengths(),
            {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
            Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());

        auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
            randval_dram_block_window_tmp, seqlen_k_start);

        auto v_dram_window =
            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                             v_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_k_start}, // TODO: hdim split?
                             Policy::template MakeVDramTileDistribution<Problem>());

        // prefetch K tile
        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
        move_tile_window(k_dram_window, {0, kK0});
        __builtin_amdgcn_sched_barrier(0);

        buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
        (void)q_element_func; // ??? rocm-6.x: using the q element func would spill to scratch
                              // on hdim=64/32
        // auto q_tile = q; // tile_elementwise_in(q_element_func, q);

        index_t i_total_loops      = 0;
        constexpr index_t k0_loops = kK0BlockLength / kK0;
        constexpr index_t k1_loops = kN0 / kK1;

        static_assert(1 <= k0_loops);
        static_assert(1 <= k1_loops);
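        // NOTE: LdsSeq assigns each of the k0_loops K sub-tiles and k1_loops V
        // sub-tiles to one of Policy::NumPrefetchK / NumPrefetchV LDS buffers, so the
        // async_load_tile_raw calls below can prefetch into one buffer while gemm_0 /
        // gemm_1 consume another. The conditional s_barrier near the end of the main
        // loop appears to guard the case where buffer 0 aliases an LDS slot that is
        // still holding a live V sub-tile.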
        // main loop
        do
        {
            // STAGE 1, QK gemm
            clear_tile(s_acc); // initialize C
            if constexpr(k0_loops > 1)
            {
                static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
                    async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
                                        k_dram_window);
                    if constexpr(i_k0 < k0_loops - 1)
                        move_tile_window(k_dram_window, {0, kK0});

                    async_load_fence(k_dram_window.get_num_access());
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    gemm_0(s_acc,
                           get_slice_tile(
                               q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
                           k_lds_load[number<LdsSeq.at(number<i_k0>{})>{}]);
#else
                           get_slice_tile(k_lds_load,
                                          sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
                                          sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
#endif
                });
            }

            // TODO: this fixes a bug when the loop is smaller than 2;
            // otherwise the following fence/barrier would be scheduled inside the 1st loop
            if constexpr(k0_loops <= 2)
                __builtin_amdgcn_sched_barrier(0);

            async_load_fence();
            __builtin_amdgcn_s_barrier();

            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
            auto v_buf           = load_tile(v_dram_window, bool_constant<false>{});
            __builtin_amdgcn_sched_barrier(0);
            { // tail
                gemm_0(s_acc,
                       get_slice_tile(
                           q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
                       k_lds_load[number<LdsSeq.at(number<k0_loops - 1>{})>{}]);
#else
                       get_slice_tile(
                           k_lds_load,
                           sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
                           sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
#endif
            }
            __builtin_amdgcn_sched_barrier(1);

            // STAGE 2, scale_s, add bias, mask, softmax
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x += type_convert<SaccDataType>(bias_element_func(y));
#else
                        x += log2e_v<SaccDataType> *
                             type_convert<SaccDataType>(bias_element_func(y));
#endif
                    },
                    s_acc,
                    bias_tile);
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto k_origin    = k_dram_block_window.get_window_origin();
                constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
                s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);

                        s_acc(i_j_idx) *= scale_s;
                        position_encoding.update(s_acc(i_j_idx), row, col);
                    });
                });
            }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
            }
            move_tile_window(bias_dram_window, {0, kN0});

            /// TODO: only check in last iteration without increasing code size
            if constexpr(kHasUnevenSplits)
            {
                const auto k_origin = k_dram_block_window.get_window_origin();
                set_tile_if(s_acc,
                            -numeric<SMPLComputeDataType>::infinity(),
                            [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
                                const auto col =
                                    k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                                return seqlen_k_end_ <= col;
                            });
            }

            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const auto k_origin      = k_dram_block_window.get_window_origin();
                bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                           k_origin.at(number<0>{}),
                                                           number<kM0>{},
                                                           number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(
                        s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                            const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                            return mask.IsOutOfBound(row, col);
                        });
                }
            }

            const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s,
                sequence<1>{},
                f_max,
                -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
            block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});

            const auto m_old = m; // m{j-1}
            tile_elementwise_inout(
                [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}

            auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
                s.get_tile_distribution()); // Pcompute{j}

            __builtin_amdgcn_sched_barrier(0x7F);
            // store & prefetch next v, after the max reduction
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                    Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                shuffle_tile(v_shuffle_tmp, v_buf);

                auto v_lds_window_tmp =
                    get_slice_tile(v_lds_window,
                                   sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
                                   sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
                store_tile(v_lds_window_tmp,
                           tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
            }
            else
            {
                auto v_lds_window_tmp =
                    get_slice_tile(v_lds_window,
                                   sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
                                   sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
                store_tile(v_lds_window_tmp,
                           tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
            }

            if constexpr(k1_loops > 1)
            {
                move_tile_window(
                    v_dram_window,
                    {0, kK1}); // would spill to scratch if moved right after load_tile(v_dram)...
                v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
            }
            __builtin_amdgcn_sched_barrier(0);

            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                /// NOTICE: bias might be a materialized mask including -inf values, which needs
                /// consideration. alibi does not have this problem
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
                               : raw_m;
                }
                else
                {
                    return raw_m;
                }
            };

            constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
                    }
#else
                    p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
                });
            });

            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
            block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});

            // l{j}, Oacc{j}
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        auto row_max = scale_s * get_validated_m(m[i_idx]);
                        return exp2(scale_s * m_old[i_idx] - row_max);
                    }
                }();
#else
                const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
                l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
                sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    // FIXME: this uses a different equation from the FA v2 paper,
                    // but produces the correct result.
                    // Is the equation wrong?
                    o_acc(i_j_idx) *= tmp;
                });
            });
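            // NOTE: the sweep above is the standard online-softmax rescale. With
            // running max m and running sum l, a new tile with local max m_local gives
            //   m_new = max(m_old, m_local)
            //   tmp   = exp(m_old - m_new)   (exp2 form under CK_TILE_FMHA_FWD_FAST_EXP2)
            //   l_new = tmp * l_old + rowsum(exp(s - m_new))
            //   o_acc = tmp * o_acc          (unnormalized; divided by l after the loop)
            // e.g. m_old = 2, l_old = 1.5, m_local = 5 gives tmp = exp(-3) ~ 0.0498,
            // shrinking the earlier contributions before the new tile is accumulated.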
            if constexpr(kHasDropout)
            {
                auto randval_ptr =
                    reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
                dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
                    randval_ptr,
                    seqlen_k_start + i_total_loops * kN0,
                    p_compute,
                    randval_dram_window);
            }

            const auto p =
                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));

            // STAGE 3, KV gemm
            if constexpr(k1_loops > 1)
            {
                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                    if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
                    {
                        v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
                    }
                    block_sync_lds();
                    gemm_1(o_acc,
                           get_slice_tile(
                               p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
                           get_slice_tile(
                               v_lds_window,
                               sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
                               sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));

                    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                            Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                        shuffle_tile(v_shuffle_tmp, v_buf);
                        auto v_lds_window_tmp = get_slice_tile(
                            v_lds_window,
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
                        store_tile(
                            v_lds_window_tmp,
                            tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
                    }
                    else
                    {
                        auto v_lds_window_tmp = get_slice_tile(
                            v_lds_window,
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
                        store_tile(v_lds_window_tmp,
                                   tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
                    }
                    if constexpr(i_k1 < k1_loops - 1)
                        move_tile_window(v_dram_window, {0, kK1});
                });
            }
            i_total_loops++;
            if(i_total_loops < num_total_loop)
            {
                // move K tile windows
                move_tile_window(k_dram_block_window, {kN0, 0});
                k_dram_window =
                    make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
                                     k_dram_block_window.get_window_lengths(),
                                     k_dram_block_window.get_window_origin(),
                                     Policy::template MakeKDramTileDistribution<Problem>());

                if constexpr(k1_loops >= 2 &&
                             LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
                    __builtin_amdgcn_s_barrier();
                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
                move_tile_window(k_dram_window, {0, kK0});
            }
            // tail
            {
                block_sync_lds();
                gemm_1(o_acc,
                       get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
                       get_slice_tile(
                           v_lds_window,
                           sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
                           sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1,
                                    kK1>{}));
            }
        } while(i_total_loops < num_total_loop);

        // store lse acc
        if constexpr(kStoreLSE)
        {
            auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
            sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
                }
                else
                {
                    lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
                }
#else
                lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
            });

            store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc));
        }

        // finally, O
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            const auto tmp       = [&]() {
                if constexpr(FmhaMask::IsMasking)
                {
                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                }
                else
                    return 1 / l[i_idx];
            }();
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });

        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }

    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename LSEaccDramBlockWindowTmp,
              typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
               LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
               index_t num_splits,
               index_t i_split,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr,
               BlockDropout& dropout) const
    {
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          bias_dram_block_window_tmp,
                          identity{},
                          randval_dram_block_window_tmp,
                          lse_acc_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          num_splits,
                          i_split,
                          mask,
                          position_encoding,
                          scale_s,
                          smem_ptr,
                          dropout);
    }
};

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
deleted 100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"

namespace ck_tile {

// In this pipeline, q/k/v are all located in LDS
using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy =
    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
                                        /* AsyncCopyK = */ true,
                                        /* AsyncCopyV = */ false,
                                        /* NumPrefetchK = */ 3,
                                        /* NumPrefetchV = */ 3>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };

-template <typename QDataType,
-          typename KDataType,
-          typename VDataType,
-          typename SaccDataType,
-          typename SMPLComputeDataType,
-          typename BiasDataType,
-          typename RandValOutputDataType,
-          typename LSEDataType,
-          typename PDataType,
-          typename OaccDataType,
-          typename ODataType,
-          typename BlockFmhaShape,
-          bool kIsGroupMode,
-          typename FmhaMask,
-          typename Traits>
+template <typename QDataType_,
+          typename KDataType_,
+          typename VDataType_,
+          typename SaccDataType_,
+          typename SMPLComputeDataType_,
+          typename BiasDataType_,
+          typename LSEDataType_,
+          typename PDataType_,
+          typename OaccDataType_,
+          typename ODataType_,
+          typename BlockFmhaShape_,
+          bool kIsGroupMode_,
+          typename FmhaMask_,
+          typename Traits_>
 struct BlockFmhaFwdSplitKVPipelineProblem
-    : BlockFmhaPipelineProblem<QDataType,
-                               KDataType,
-                               VDataType,
-                               SaccDataType,
-                               SMPLComputeDataType,
-                               BiasDataType,
-                               RandValOutputDataType,
-                               LSEDataType,
-                               PDataType,
-                               OaccDataType,
-                               ODataType,
-                               BlockFmhaShape,
-                               kIsGroupMode,
-                               FmhaMask,
-                               Traits>
 {
-    static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
+    using QDataType           = remove_cvref_t<QDataType_>;
+    using KDataType           = remove_cvref_t<KDataType_>;
+    using VDataType           = remove_cvref_t<VDataType_>;
+    using SaccDataType        = remove_cvref_t<SaccDataType_>;
+    using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
+    using BiasDataType        = remove_cvref_t<BiasDataType_>;
+    using LSEDataType         = remove_cvref_t<LSEDataType_>;
+    using PDataType           = remove_cvref_t<PDataType_>;
+    using OaccDataType        = remove_cvref_t<OaccDataType_>;
+    using ODataType           = remove_cvref_t<ODataType_>;
+    using BlockFmhaShape      = remove_cvref_t<BlockFmhaShape_>;
+    using FmhaMask            = remove_cvref_t<FmhaMask_>;
+    using Traits              = remove_cvref_t<Traits_>;
+
+    static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
+    static constexpr bool kIsGroupMode  = kIsGroupMode_;
+
+    // attributes from traits
+    static constexpr bool kPadSeqLenQ       = Traits::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK       = Traits::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ      = Traits::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV      = Traits::kPadHeadDimV;
+    static constexpr auto BiasEnum          = Traits::BiasEnum;
+    static constexpr bool kStoreLSE         = Traits::kStoreLSE;
+    static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
+    static constexpr bool kIsPagedKV        = Traits::kIsPagedKV;
+    static constexpr bool kHasUnevenSplits  = kIsGroupMode || Traits::kHasUnevenSplits;
+    static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
 };

 template <typename LSEDataType_,

@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem
     static constexpr index_t kMaxSplits = Traits::kMaxSplits;
 };

+template <typename QDataType_,
+          typename KDataType_,
+          typename VDataType_,
+          index_t kM0_,
+          index_t kN0_,
+          index_t kK0_,
+          index_t kN1_,
+          bool kIsVLayoutRowMajor_,
+          RotaryEmbeddingEnum RotaryEnum_,
+          bool kIsPagedKV_,
+          typename Traits_>
+struct BlockFmhaFwdAppendKVPipelineProblem
+{
+    using QDataType = remove_cvref_t<QDataType_>;
+    using KDataType = remove_cvref_t<KDataType_>;
+    using VDataType = remove_cvref_t<VDataType_>;
+    using Traits    = remove_cvref_t<Traits_>;
+
+    static constexpr index_t kBlockSize = 256;
+
+    static constexpr index_t kM0 = kM0_;
+    static constexpr index_t kN0 = kN0_;
+    static constexpr index_t kK0 = kK0_;
+    static constexpr index_t kN1 = kN1_;
+
+    using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
+                                       ck_tile::tensor_layout::gemm::RowMajor,
+                                       ck_tile::tensor_layout::gemm::ColumnMajor>;
+
+    static constexpr auto RotaryEnum = RotaryEnum_;
+    static constexpr bool kIsPagedKV = kIsPagedKV_;
+
+    // attributes from traits
+    static constexpr bool kPadSeqLenQ    = Traits::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK    = Traits::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ   = Traits::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV   = Traits::kPadHeadDimV;
+    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
+};
+
 } // namespace ck_tile
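For reference, instantiating the new appendkv problem type might look like the following. This is a hypothetical sketch: the data types, tile sizes, and trait values are placeholders, and the enumerator name INTERLEAVED is an assumption (only HALF_ROTATED appears in this commit).

// Hypothetical instantiation (fp16, 64x64 tiles, hdim 128, interleaved rotary, paged KV):
using AppendKVTraits = ck_tile::TileFmhaFwdAppendKVTraits<true,  // kPadSeqLenQ
                                                          true,  // kPadSeqLenK
                                                          false, // kPadHeadDimQ
                                                          false, // kPadHeadDimV
                                                          -1>;   // kBlockPerCu

using AppendKVProblem = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
    ck_tile::fp16_t, // QDataType
    ck_tile::fp16_t, // KDataType
    ck_tile::fp16_t, // VDataType
    64,              // kM0
    64,              // kN0
    128,             // kK0
    128,             // kN1
    true,            // kIsVLayoutRowMajor
    ck_tile::RotaryEmbeddingEnum::INTERLEAVED, // assumed enumerator name
    true,            // kIsPagedKV
    AppendKVTraits>;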
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
@@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
 {
     if constexpr(AsyncCopyK)
     {
-        return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
+        return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
     }
     else
     {
-        return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>());
+        return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
     }
 }

+// this method is only available when Problem::kHasDropout is present
 template <typename Problem>
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
+CK_TILE_HOST_DEVICE static constexpr std::enable_if_t<
+    std::is_convertible_v<decltype(Problem::kHasDropout), bool>,
+    ck_tile::index_t>
+GetSmemSizeDropout(int)
 {
     if constexpr(Problem::kHasDropout)
     {

@@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     }
 }

+// fallback version if Problem::kHasDropout does not exist
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout(...)
+{
+    return 0;
+}
+
 template <typename Problem>
 CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
 {
...
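The GetSmemSizeDropout change above is the classic two-overload detection idiom: the `(int)` overload is only viable when `Problem::kHasDropout` exists (enforced by the `enable_if_t` on the return type), and the `(...)` overload is a strictly worse match for the literal argument `0`, so it is chosen only as a fallback. A self-contained illustration of the same pattern (hypothetical names, not this file's code):

#include <cstdio>
#include <type_traits>

struct WithFlag    { static constexpr bool kHasDropout = true; };
struct WithoutFlag { };

// preferred overload: viable only when P::kHasDropout exists and converts to bool
template <typename P>
constexpr std::enable_if_t<std::is_convertible_v<decltype(P::kHasDropout), bool>, int>
smem_size_dropout(int)
{
    return P::kHasDropout ? 1024 : 0;
}

// fallback: an ellipsis parameter ranks below any standard conversion, so this
// overload is picked only when the one above is removed by SFINAE
template <typename P>
constexpr int smem_size_dropout(...)
{
    return 0;
}

int main()
{
    std::printf("%d %d\n",
                smem_size_dropout<WithFlag>(0),     // 1024: flag present
                smem_size_dropout<WithoutFlag>(0)); // 0: fallback chosen
}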
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -32,30 +33,31 @@ struct TileFmhaTraits
     static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };

-template <bool kPadSeqLenQ /* padding for seqlen_q */,
-          bool kPadSeqLenK /* padding for seqlen_k */,
-          bool kPadHeadDimQ /* padding for hdim_q */,
-          bool kPadHeadDimV /* padding for hdim_v */,
-          BlockAttentionBiasEnum BiasEnum,
-          bool kHasBiasGrad,
-          bool kStoreLSE,
-          bool kHasDropout,
-          bool kDoFp8StaticQuant,
-          bool kHasUnevenSplits_ = true,
-          index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */>
-struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
-                                                 kPadSeqLenK,
-                                                 kPadHeadDimQ,
-                                                 kPadHeadDimV,
-                                                 BiasEnum,
-                                                 kHasBiasGrad,
-                                                 kStoreLSE,
-                                                 kHasDropout,
-                                                 kDoFp8StaticQuant,
-                                                 kBlockPerCu>
+template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
+          bool kPadSeqLenK_ /* padding for seqlen_k */,
+          bool kPadHeadDimQ_ /* padding for hdim_q */,
+          bool kPadHeadDimV_ /* padding for hdim_v */,
+          BlockAttentionBiasEnum BiasEnum_,
+          bool kHasBiasGrad_,
+          bool kStoreLSE_,
+          bool kDoFp8StaticQuant_,
+          bool kIsPagedKV_,
+          bool kHasUnevenSplits_,
+          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
+struct TileFmhaFwdSplitKVTraits
 {
+    static constexpr bool kPadSeqLenQ       = kPadSeqLenQ_;
+    static constexpr bool kPadSeqLenK       = kPadSeqLenK_;
+    static constexpr bool kPadHeadDimQ      = kPadHeadDimQ_;
+    static constexpr bool kPadHeadDimV      = kPadHeadDimV_;
+    static constexpr auto BiasEnum          = BiasEnum_;
+    static constexpr bool kHasBiasGrad      = kHasBiasGrad_;
+    static constexpr bool kStoreLSE         = kStoreLSE_;
+    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+    static constexpr bool kIsPagedKV        = kIsPagedKV_;
     // determine if some split (length) is not divisible by tile size
     static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
+    static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };
 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
...
@@ -76,6 +78,20 @@ struct TileFmhaFwdSplitKVCombineTraits
     static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };

+template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
+          bool kPadSeqLenK_ /* padding for seqlen_k */,
+          bool kPadHeadDimQ_ /* padding for hdim_q */,
+          bool kPadHeadDimV_ /* padding for hdim_v */,
+          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
+struct TileFmhaFwdAppendKVTraits
+{
+    static constexpr bool kPadSeqLenQ  = kPadSeqLenQ_;
+    static constexpr bool kPadSeqLenK  = kPadSeqLenK_;
+    static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
+    static constexpr bool kPadHeadDimV = kPadHeadDimV_;
+    static constexpr index_t kBlockPerCu = kBlockPerCu_;
+};
+
 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kPadHeadDimV_ /* padding for hdim_v */,
           index_t kBlockPerCu_ = 2 /* hint to occupancy */>
...
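The rewritten traits drop the inheritance indirection: each `_`-suffixed template parameter is re-exported as a `static constexpr` member, so downstream pipelines read `Traits::kPadSeqLenQ` and never depend on the parameter list itself. A minimal sketch of the pattern and of a consumer that stays decoupled from it (`AppendKVTraits` and `needs_any_padding` are illustrative names, not library code):

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t in this sketch

// Mirrors the shape of TileFmhaFwdAppendKVTraits from the diff.
template <bool kPadSeqLenQ_, bool kPadSeqLenK_, bool kPadHeadDimQ_, bool kPadHeadDimV_,
          index_t kBlockPerCu_ = -1>
struct AppendKVTraits
{
    static constexpr bool kPadSeqLenQ  = kPadSeqLenQ_;
    static constexpr bool kPadSeqLenK  = kPadSeqLenK_;
    static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
    static constexpr bool kPadHeadDimV = kPadHeadDimV_;
    static constexpr index_t kBlockPerCu = kBlockPerCu_;
};

// A consumer only ever reads the exported members, so traits structs can add
// or reorder template parameters without breaking callers.
template <typename Traits>
constexpr bool needs_any_padding()
{
    return Traits::kPadSeqLenQ || Traits::kPadSeqLenK ||
           Traits::kPadHeadDimQ || Traits::kPadHeadDimV;
}

using MyTraits = AppendKVTraits<true, true, false, false>;
static_assert(needs_any_padding<MyTraits>());
static_assert(MyTraits::kBlockPerCu == -1); // default occupancy hint

int main() {}
```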
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
...
@@ -74,6 +74,10 @@ using GNWK = ck::tensor_layout::convolution::GNWK;
 using GNHWK  = ck::tensor_layout::convolution::GNHWK;
 using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
+using NGKW   = ck::tensor_layout::convolution::NGKW;
+using NGKHW  = ck::tensor_layout::convolution::NGKHW;
+using NGKDHW = ck::tensor_layout::convolution::NGKDHW;
 //
 using NWGC  = ck::tensor_layout::convolution::NWGC;
 using NHWGC = ck::tensor_layout::convolution::NHWGC;
@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
...
@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
using
NHWGK
=
ck
::
tensor_layout
::
convolution
::
NHWGK
;
using
NHWGK
=
ck
::
tensor_layout
::
convolution
::
NHWGK
;
using
NDHWGK
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
using
NDHWGK
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
using
NGCW
=
ck
::
tensor_layout
::
convolution
::
NGCW
;
using
NGCHW
=
ck
::
tensor_layout
::
convolution
::
NGCHW
;
using
NGCDHW
=
ck
::
tensor_layout
::
convolution
::
NGCDHW
;
//
//
using
G_K
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
G_K
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
GK_Tuple
=
ck
::
Tuple
<
G_K
>
;
using
GK_Tuple
=
ck
::
Tuple
<
G_K
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
...
@@ -56,6 +56,46 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std
    // clang-format on
    >;
// NGCHW requires a transpose, so we use vector load/store parameters for these instances
template <ck::index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename ELayout,
          ConvolutionBackwardWeightSpecialization ConvSpec,
          BlockGemmPipelineScheduler Scheduler,
          BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances = std::tuple<
    // clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 1, 2>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Empty_Tuple = ck::Tuple<>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
    // clang-format on
    >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
    // clang-format on
    >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32>
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
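The three instances in each tuple above are identical except for the trailing NumGroupsToMerge argument (8, 16, 32). Assuming the merge factor folds that many consecutive groups into a single GEMM batch, which is what the "NumGroups ToMerge" column suggests, the number of launched GEMM batches shrinks accordingly; a rough arithmetic sketch (the helper name is illustrative, not a library API):

```cpp
#include <cassert>

// Hypothetical helper: number of GEMM batches a grouped conv would launch
// when num_groups_to_merge consecutive groups are folded into one GEMM.
// Assumes G is divisible by the merge factor, which an instance's
// IsSupportedArgument check would normally have to verify.
int merged_gemm_batches(int G, int num_groups_to_merge)
{
    assert(G % num_groups_to_merge == 0);
    return G / num_groups_to_merge;
}

int main()
{
    // With G = 32 groups, the three instance variants would map to
    // 4, 2, and 1 GEMM batches respectively.
    assert(merged_gemm_batches(32, 8) == 4);
    assert(merged_gemm_batches(32, 16) == 2);
    assert(merged_gemm_batches(32, 32) == 1);
}
```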
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...
@@ -367,6 +367,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
                 op_ptrs);
         }
 #endif
+        if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKYXC> &&
+                     is_same_v<OutLayout, NGKHW>)
+        {
+#ifdef CK_ENABLE_FP16
+            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                         is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
+                         is_same_v<ComputeTypeB, half_t>)
+            {
+                add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
+                    op_ptrs);
+                add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
+                    op_ptrs);
+            }
+#endif
+        }
     }
...
@@ -447,6 +462,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
                 op_ptrs);
         }
 #endif
+        if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKZYXC> &&
+                     is_same_v<OutLayout, NGKDHW>)
+        {
+#ifdef CK_ENABLE_FP16
+            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                         is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
+                         is_same_v<ComputeTypeB, half_t>)
+            {
+                add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
+                    op_ptrs);
+                add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
+                    op_ptrs);
+            }
+#endif
+        }
     }
...
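For orientation, a sketch of the client-side pattern these new factory branches serve: `DeviceOperationInstanceFactory` collects every instance registered for a given operation signature, so a caller naming the NGCHW/GKYXC/NGKHW f16 signature now also receives the two-stage pipev2/pipev5 instances. This assumes the usual CK factory usage, with the operation type spelled out as in the .inc declarations below:

```cpp
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

// The DeviceGroupedConvBwdWeight signature the new NGCHW f16 instances
// register for (matching the declarations in the .inc file below).
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
    2,                                     // NDimSpatial
    ck::tensor_layout::convolution::NGCHW, // input layout
    ck::tensor_layout::convolution::GKYXC, // weight layout
    ck::tensor_layout::convolution::NGKHW, // output layout
    ck::half_t, ck::half_t, ck::half_t,    // in/wei/out data types
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

void list_instances()
{
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // op_ptrs now includes the pipev2/pipev5 two-stage instances added above,
    // alongside any other instances registered for this signature; a client
    // would filter them with IsSupportedArgument and profile the survivors.
}
```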
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
...
@@ -137,6 +137,29 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
     PassThrough,
     PassThrough,
     PassThrough>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NGCHW, GKYXC, NGKHW, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NGCHW, GKYXC, NGKHW, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
 #endif
 #ifdef CK_ENABLE_FP32
 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
...
@@ -240,6 +263,29 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
     PassThrough,
     PassThrough,
     PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NGCDHW, GKZYXC, NGKDHW, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NGCDHW, GKZYXC, NGKDHW, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
 #endif
 #ifdef CK_ENABLE_FP32
 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...
@@ -18,6 +18,7 @@
 #ifdef CK_USE_XDL
 #include "grouped_convolution_forward_xdl.inc"
 #include "grouped_convolution_forward_xdl_large_tensor.inc"
+#include "grouped_convolution_forward_xdl_merged_groups.inc"
 #include "grouped_convolution_forward_comp_xdl.inc"
 #include "grouped_convolution_forward_mem_inter_xdl.inc"
 #include "grouped_convolution_forward_mem_intra_xdl.inc"
...
@@ -202,6 +203,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
             op_ptrs);
+        add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
+            op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
             op_ptrs);
...
@@ -217,6 +220,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
             op_ptrs);
+        add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
+            op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
             op_ptrs);
...
@@ -234,6 +239,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
             op_ptrs);
+        add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
+            op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
         add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
             op_ptrs);
...
@@ -293,6 +300,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
             op_ptrs);
+        add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
+            op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
             op_ptrs);
...
@@ -349,6 +358,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
             op_ptrs);
+        add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+            op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
             op_ptrs);
...
@@ -366,6 +377,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
             op_ptrs);
+        add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+            op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
         add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
             op_ptrs);
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
     {
         return {0, 1, 2, 3};
     }
+    else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCW> ||
+                      ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKW>)
+    {
+        return {1, 0, 2, 3};
+    }
+    else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCHW> ||
+                      ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKHW>)
+    {
+        return {1, 0, 2, 3, 4};
+    }
+    else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCDHW> ||
+                      ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKDHW>)
+    {
+        return {1, 0, 2, 3, 4, 5};
+    }
     else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
                       ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
                       ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
...
@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
                                param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
     // separate from legacy code above
+    else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCW> ||
+                      ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCHW> ||
+                      ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCDHW>)
+    {
+        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
+                                                    static_cast<std::size_t>(param.G_),
+                                                    static_cast<std::size_t>(param.C_)};
+
+        physical_lengths.insert(physical_lengths.end(),
+                                param.input_spatial_lengths_.begin(),
+                                param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
+    }
     else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
                       ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
                       ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
...
@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP
                                 param.output_spatial_lengths_.begin(),
                                 param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
+    // separate from legacy code above
+    else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKW> ||
+                      ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKHW> ||
+                      ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKDHW>)
+    {
+        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
+                                                    static_cast<std::size_t>(param.G_),
+                                                    static_cast<std::size_t>(param.K_)};
+
+        physical_lengths.insert(physical_lengths.end(),
+                                param.output_spatial_lengths_.begin(),
+                                param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
+    }
     else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
                       ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
                       ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
...
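The new branches in `get_layout_transpose_gnchw_to_old` return the permutation taking the canonical G, N, C/K, spatial ordering to the physical N, G, C/K ordering; for NGCHW that is {1, 0, 2, 3, 4}, i.e. swap the first two axes. A tiny worked sketch of applying such a permutation to a shape (`apply_permutation` is an illustrative helper, not library code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Applies perm to lengths: out[i] = lengths[perm[i]]. This is how a shape in
// the canonical GNCHW order can be re-expressed in the physical layout order.
std::vector<std::size_t> apply_permutation(const std::vector<std::size_t>& lengths,
                                           const std::vector<std::size_t>& perm)
{
    std::vector<std::size_t> out(perm.size());
    for(std::size_t i = 0; i < perm.size(); ++i)
        out[i] = lengths[perm[i]];
    return out;
}

int main()
{
    // canonical GNCHW lengths: G=2, N=4, C=8, H=16, W=16
    const std::vector<std::size_t> gnchw{2, 4, 8, 16, 16};
    // the permutation returned for NGCHW above
    const std::vector<std::size_t> perm{1, 0, 2, 3, 4};
    const auto ngchw = apply_permutation(gnchw, perm);
    assert((ngchw == std::vector<std::size_t>{4, 2, 8, 16, 16})); // N and G swapped
}
```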
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
...
@@ -8,6 +8,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
    xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
    xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
    xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
+   xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
+   xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
 )
 if(DL_KERNELS)
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, y, x, c] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NGCHW, GKYXC, NGKHW, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
            2,
            NGCHW,
            GKYXC,
            NGKHW,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v2>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, y, x, c] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NGCHW, GKYXC, NGKHW, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
            2,
            NGCHW,
            GKYXC,
            NGKHW,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck