GitLab · gaoqiong / composable_kernel_ROCM · Commits

Commit 909f519c (Unverified)
Authored Jun 27, 2024 by Harisankar Sadasivan; committed by GitHub, Jun 27, 2024

    Merge branch 'develop' into universal_streamk

Parents: 406fa265, 3bb0fe6c
Changes: 82

Showing 20 changed files with 822 additions and 1232 deletions (+822 −1232)
example/ck_tile/01_fmha/fmha_fwd.hpp  (+215 −0)
example/ck_tile/01_fmha/generate.py  (+33 −1184)
include/ck/ck.hpp  (+5 −9)
include/ck/host_utility/device_prop.hpp  (+5 −0)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  (+499 −0)
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  (+7 −0)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp  (+16 −7)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp  (+4 −3)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  (+5 −4)
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp  (+1 −3)
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp  (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp  (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp  (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp  (+4 −3)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp  (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp  (+15 −8)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp  (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp  (+3 −2)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp  (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp  (+4 −3)
example/ck_tile/01_fmha/fmha_fwd.hpp  (view file @ 909f519c)
@@ -93,6 +93,8 @@ struct fmha_fwd_args
    const void* v_ptr;
    const void* bias_ptr; // bias or alibi_slope pointer
    void* rand_val_ptr;
    void* lse_acc_ptr;
    void* o_acc_ptr;
    void* lse_ptr;
    void* o_ptr;
    const void* seqstart_q_ptr;

@@ -106,6 +108,7 @@ struct fmha_fwd_args
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;
    ck_tile::index_t num_splits;
    float scale_s;
    float scale_p;
    float scale_o;

@@ -114,6 +117,7 @@ struct fmha_fwd_args
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
    ck_tile::index_t stride_randval;
    ck_tile::index_t stride_o_acc;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;

@@ -121,6 +125,8 @@ struct fmha_fwd_args
    ck_tile::index_t nhead_stride_bias;
    ck_tile::index_t nhead_stride_randval;
    ck_tile::index_t nhead_stride_lse;
    ck_tile::index_t nhead_stride_lse_acc;
    ck_tile::index_t nhead_stride_o_acc;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;

@@ -128,7 +134,11 @@ struct fmha_fwd_args
    ck_tile::index_t batch_stride_bias;
    ck_tile::index_t batch_stride_randval;
    ck_tile::index_t batch_stride_lse;
    ck_tile::index_t batch_stride_lse_acc;
    ck_tile::index_t batch_stride_o_acc;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t split_stride_lse_acc;
    ck_tile::index_t split_stride_o_acc;
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
@@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
    return ck_tile::make_tuple(kargs, grids);
}

template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(Kernel::kIsGroupMode)
        {
            return Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr,
                                     args.rand_val_ptr, args.lse_acc_ptr, args.o_acc_ptr,
                                     args.batch, args.max_seqlen_q,
                                     args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr,
                                     args.hdim_q, args.hdim_v,
                                     args.nhead_q, args.nhead_q / args.nhead_k,
                                     args.num_splits,
                                     args.scale_s, args.scale_p,
                                     args.stride_q, args.stride_k, args.stride_v,
                                     args.stride_bias, args.stride_randval, args.stride_o_acc,
                                     args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v,
                                     args.nhead_stride_bias, args.nhead_stride_randval,
                                     args.nhead_stride_lse_acc, args.nhead_stride_o_acc,
                                     args.batch_stride_lse_acc, args.batch_stride_o_acc,
                                     args.split_stride_lse_acc, args.split_stride_o_acc,
                                     args.window_size_left, args.window_size_right, args.mask_type,
                                     args.p_drop, args.s_randval, args.drop_seed_offset);
        }
        else
        {
            // create batch mode kernel arguments
            return Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr,
                                     args.rand_val_ptr, args.lse_acc_ptr, args.o_acc_ptr,
                                     args.batch, args.max_seqlen_q,
                                     args.seqlen_q, args.seqlen_k,
                                     args.hdim_q, args.hdim_v,
                                     args.nhead_q, args.nhead_q / args.nhead_k,
                                     args.num_splits,
                                     args.scale_s, args.scale_p,
                                     args.stride_q, args.stride_k, args.stride_v,
                                     args.stride_bias, args.stride_randval, args.stride_o_acc,
                                     args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v,
                                     args.nhead_stride_bias, args.nhead_stride_randval,
                                     args.nhead_stride_lse_acc, args.nhead_stride_o_acc,
                                     args.batch_stride_q, args.batch_stride_k, args.batch_stride_v,
                                     args.batch_stride_bias, args.batch_stride_randval,
                                     args.batch_stride_lse_acc, args.batch_stride_o_acc,
                                     args.split_stride_lse_acc, args.split_stride_o_acc,
                                     args.window_size_left, args.window_size_right, args.mask_type,
                                     args.p_drop, args.s_randval, args.drop_seed_offset);
        }
    }();

    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits);

    return ck_tile::make_tuple(kargs, grids);
}

template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(Kernel::kIsGroupMode)
        {
            return Kernel::MakeKargs(args.lse_acc_ptr, args.o_acc_ptr, args.lse_ptr, args.o_ptr,
                                     args.batch, args.max_seqlen_q, args.seqstart_q_ptr,
                                     args.hdim_v, args.num_splits, args.scale_o,
                                     args.stride_o_acc, args.stride_o,
                                     args.nhead_stride_lse_acc, args.nhead_stride_o_acc,
                                     args.nhead_stride_lse, args.nhead_stride_o,
                                     args.batch_stride_lse_acc, args.batch_stride_o_acc,
                                     args.batch_stride_lse,
                                     args.split_stride_lse_acc, args.split_stride_o_acc);
        }
        else
        {
            // create batch mode kernel arguments
            return Kernel::MakeKargs(args.lse_acc_ptr, args.o_acc_ptr, args.lse_ptr, args.o_ptr,
                                     args.batch, args.max_seqlen_q, args.seqlen_q,
                                     args.hdim_v, args.num_splits, args.scale_o,
                                     args.stride_o_acc, args.stride_o,
                                     args.nhead_stride_lse_acc, args.nhead_stride_o_acc,
                                     args.nhead_stride_lse, args.nhead_stride_o,
                                     args.batch_stride_lse_acc, args.batch_stride_o_acc,
                                     args.batch_stride_lse, args.batch_stride_o,
                                     args.split_stride_lse_acc, args.split_stride_o_acc);
        }
    }();

    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);

    return ck_tile::make_tuple(kargs, grids);
}

// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
          typename DataType_,
          ...
@@ -282,6 +462,40 @@ struct fmha_fwd_traits_
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);

template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();

template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN1_,
          bool kStoreLse_,
          bool kDoFp8StaticQuant_,
          bool kPadS_,
          bool kPadDv_>
struct fmha_fwd_splitkv_combine_traits_
{
    static constexpr ck_tile::index_t HDim  = HDim_;
    using DataType                          = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode      = kIsGroupMode_;
    static constexpr ck_tile::index_t kM0   = kM0_;
    static constexpr ck_tile::index_t kN1   = kN1_;
    static constexpr bool kStoreLse         = kStoreLse_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr bool kPadS             = kPadS_;
    static constexpr bool kPadDv            = kPadDv_;
};

template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();

// This is the public API, will be generated by script
struct fmha_fwd_traits
{

@@ -298,3 +512,4 @@ struct fmha_fwd_traits
    // TODO: padding check is inside this api
};

float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
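The header above wires the new split-KV arguments (lse_acc/o_acc and their split strides) through to the kernels. For orientation only — this sketch is not part of the commit and the names/shapes are made up — the combine step reduces per-split partial outputs using their per-split log-sum-exp values, which for one query row of one head looks roughly like this NumPy sketch:

import numpy as np

# num_splits partial results for one (batch, head, query-row):
# o_acc[s]   - partial attention output accumulated over KV-split s
# lse_acc[s] - log-sum-exp of the attention logits within split s
num_splits, hdim_v = 4, 64
rng = np.random.default_rng(0)
o_acc = rng.standard_normal((num_splits, hdim_v))
lse_acc = rng.standard_normal(num_splits)

lse = np.logaddexp.reduce(lse_acc)    # combined log-sum-exp over all splits
w = np.exp(lse_acc - lse)             # per-split renormalization weights
o = (w[:, None] * o_acc).sum(axis=0)  # combined output, shape (hdim_v,)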
example/ck_tile/01_fmha/generate.py  (view file @ 909f519c)
@@ -3,1214 +3,62 @@
# generate kernel instances to speed up compilation
import argparse
import itertools
from enum import IntEnum
from pathlib import Path
from typing import List, Optional, Tuple
from dataclasses import dataclass
import copy
import fnmatch
from typing import List, Optional

DTYPE_MAP = {
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp8" : "ck_tile::fp8_t"
}

DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
    "bf16": 16,
    "fp8" : 8,
    "bf8" : 8
}

MASK_IMPL = {
    "generic"    : "ck_tile::GenericAttentionMask",
    "simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}

MASK_SIMPLIFIED_MAP = {
    "s_no"   : "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}

MASK_MAP = {
    "no"      : "FmhaMasks::NoMask",
    "causal"  : "FmhaMasks::CausalMask",
    "generic" : "FmhaMasks::GenericMask"
}

BIAS_MAP = {
    "no"    : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
    "bias"  : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
    "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}

# TODO: this is ugly
BIAS_CHECK_MAP = {
    "no"    : "bias_enum::no_bias",
    "bias"  : "bias_enum::elementwise_bias",
    "alibi" : "bias_enum::alibi"
}

MODE_MAP = {
    "batch" : "false",
    "group" : "true"
}

LAYOUT_MAP = {
    "row" : "true",
    "col" : "false"
}

PIPELINE_MAP = {
    "qr"       : "ck_tile::BlockFmhaPipelineQRKSVS",
    "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}

PIPELINE_ENUM_MAP = {
    "qr"       : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}

BOOL_MAP = {
    "t" : "true",
    "f" : "false"
}

TILE_PARTITIONER_MAP = {
    "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
    "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
}

GEN_DIR = ""    # in Cmake, have to generate files in same folder

FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
\n// auto generated by generate.py
#include "fmha_fwd.hpp"
"""

FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;

using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
                                                  fmha_block_warps_{F_idx},
                                                  fmha_warp_tile_{F_idx},
                                                  fmha_block_warps_{F_idx},
                                                  fmha_warp_tile_{F_idx},
                                                  {F_vlayout}>;

using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                   {F_skpad},
                                                   {F_dpad},
                                                   {F_dvpad},
                                                   {F_bias},
                                                   false,
                                                   {F_lse},
                                                   {F_dropout},
                                                   {F_squant},
                                                   {F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    fmha_shape_{F_idx},
    {F_mode},
    fmha_mask_{F_idx},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = {F_pipeline}<
    fmha_pipeline_problem_{F_idx}>;

using fmha_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                                                 typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
                                                                 {F_spad}, {F_dvpad}>>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
                           fmha_pipeline_{F_idx},
                           fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
                                       {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

#include <iostream>

template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp"

FMHA_FWD_API = """
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

FMHA_FWD_API_PER_DTYPE = """    {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
    }}
"""

FMHA_FWD_API_PER_HDIM_CASE = """        {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
        }}
"""

MASK_CHECK_MAP = {
    "no"      : "t.mask_type == mask_enum::no_mask",
    "causal"  : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic" : "t.mask_type == mask_enum::window_generic",
}

MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no"   : "t.mask_type == mask_enum::no_mask",
    "s_mask" : "t.mask_type != mask_enum::no_mask",
}

FMHA_FWD_API_INNER_DISPATCH = """            {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
                        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
                using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
                return fmha_fwd_<trait_>(s, a);
            }}
"""
def get_mask_map(mask : str):
    if mask == "generic":
        return MASK_MAP
    elif mask == "simplified":
        return MASK_SIMPLIFIED_MAP
    else:
        assert False
        return None

def get_mask_check_map(mask : str):
    if mask == "generic":
        return MASK_CHECK_MAP
    elif mask == "simplified":
        return MASK_SIMPLIFIED_CHECK_MAP
    else:
        assert False
        return None
@dataclass
class FmhaFwdApiTrait:
    pipeline_tag : str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim    : str
    dtype   : str  # data type
    mode    : str  # value from MODE_MAP
    bm0     : int  # tile size along q seqlen (block size)
    bn0     : int  # tile size along qk seqlen
    bk0     : int  # tile size along qk gemm unroll
    bn1     : int  # tile size along v head_dim
    bk1     : int  # tile size along kv gemm unroll
    bk0blen : int
    vlayout : str
    mask    : str
    bias    : str
    lse     : str
    dropout : str
    squant  : str
    spad    : str
    skpad   : str
    dpad    : str
    dvpad   : str

    @property
    def name(self) -> str:
        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-' + \
               f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    @property
    def scheck(self) -> str:
        if self.mode == 'group':
            return 'true/*group mode spad always true*/'  # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.spad == 't':
                return 'true'  # always support
            else:
                return 'true'
        elif self.pipeline_tag in ['qr']:
            if self.spad == 't':
                return f'true /*a.seqlen_q % {self.bm0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.seqlen_q % {self.bm0} == 0'
        else:
            assert False

    @property
    def skcheck(self) -> str:
        if self.mode == 'group':
            return 'true/*group mode skpad always true*/'  # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't':
                return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else:
                return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr', 'qr_fp8']:
            if self.skpad == 't':
                return f'true /*a.seqlen_k % {self.bn0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.seqlen_k % {self.bn0} == 0'
        else:
            assert False

    @property
    def dcheck(self) -> str:
        if self.pipeline_tag == 'qr_async':
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == 't':
                return f'a.hdim_q % {vec} == 0'
            else:
                assert False
        elif self.pipeline_tag in ['qr']:
            if self.dpad == 't':
                return f'true /*a.hdim_q % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.hdim_q % {self.bk0blen} == 0'
        else:
            assert False

    @property
    def dvcheck(self) -> str:
        if self.pipeline_tag == 'qr_async':
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == 't':
                return f'a.hdim_v % {vec} == 0'
            else:
                assert False
        elif self.pipeline_tag in ['qr']:
            if self.dvpad == 't':
                return f'true /*a.hdim_v % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.hdim_v % {self.bk0blen} == 0'
        else:
            assert False
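A small illustrative use of FmhaFwdApiTrait — the tile values below are made up — showing how the runtime guard strings follow from the tile shape and padding flags:

# Illustrative only; not part of generate.py.
t = FmhaFwdApiTrait(pipeline_tag='qr', hdim='64', dtype='fp16', mode='batch',
                    bm0=128, bn0=64, bk0=32, bn1=64, bk1=32, bk0blen=64,
                    vlayout='row', mask='s_no', bias='no', lse='f',
                    dropout='f', squant='f',
                    spad='f', skpad='f', dpad='f', dvpad='f')
print(t.scheck)   # -> a.seqlen_q % 128 == 0
print(t.skcheck)  # -> a.seqlen_k % 64 == 0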
@dataclass
class FmhaFwdPipeline:
    tag : str

    F_vlayout : str  # row/col
    F_spad    : str  # true/false
    F_skpad   : str
    F_dpad    : str
    F_dvpad   : str
    F_bias    : str  # true/false
    F_lse     : str
    F_dropout : str
    F_squant  : str
    F_mask    : str  # value from MASK_MAP

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_skpad == 't': n += 'sk'
            if self.F_dpad == 't': n += 'd'
            if self.F_dvpad == 't': n += 'dv'
            if n != '': n = 'p' + n
            return n
        pn = pad_name()
        n = f'{self.tag}_v{self.F_vlayout[0]}'
        if pn != '': n += f'_{pn}'
        if self.F_bias != 'no': n += f'_{self.F_bias}'
        if self.F_mask[0:2] == 's_':
            if self.F_mask == 's_mask': n += f'_mask'
        else:
            if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
        if self.F_lse == 't': n += '_lse'
        if self.F_dropout == 't': n += '_dropout'
        if self.F_squant == 't': n += '_squant'
        return n
class FmhaFwdApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()
        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                        F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                        F_mask=get_mask_map(self.mask_impl)[trait.mask],
                        F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                        F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                        F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
                        F_squant=BOOL_MAP[trait.squant],
                        F_scheck=trait.scheck, F_skcheck=trait.skcheck,
                        F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                        F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                        F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0,
                        F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
                        F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
@dataclass
class FmhaFwdTileSize:
    F_bm0     : int  # tile size along q seqlen (block size)
    F_bn0     : int  # tile size along k seqlen
    F_bk0     : int  # tile size along qk gemm unroll
    F_bn1     : int  # tile size along v head_dim
    F_bk1     : int  # tile size along kv gemm unroll
    F_bk0blen : int  # total length of K0, used for pipelines that need to load Q at once (or repeatedly load Q as a whole tile)
    F_rm      : int  # number of warps along q seqlen (block warps)
    F_rn      : int  # number of warps along k seqlen (not used)
    F_rk      : int  # number of warps along gemm-k (not used)
    F_wm      : int  # warp size along m (warp size)
    F_wn      : int  # warp size along n
    F_wk      : int  # warp size along k
    F_occupancy : int  # occupancy, -1 will let the pipeline decide the occupancy, any other value overrides it

    @property
    def name(self) -> str:
        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" + \
               f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + \
               ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdKernel:
    direction  : str
    F_idx      : int   # this is not a tunable, but a counter to differentiate symbol
    F_hdim     : int   # hdim
    F_dtype    : str   # data type
    F_mode     : str   # value from MODE_MAP
    F_tile     : FmhaFwdTileSize
    F_pipeline : FmhaFwdPipeline
    mask_impl  : str

    def get_tp(self) -> str:
        if self.F_mode == 'group':
            return 'hbs'
        else:
            return 'shb'

    @property
    def template(self) -> str:
        kernel_body = str()
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_KERNEL_BODY.format(
                F_idx=self.F_idx, F_hdim=self.F_hdim, F_dtype=DTYPE_MAP[self.F_dtype],
                F_bm0=self.F_tile.F_bm0, F_bn0=self.F_tile.F_bn0, F_bk0=self.F_tile.F_bk0,
                F_bn1=self.F_tile.F_bn1, F_bk1=self.F_tile.F_bk1, F_bk0blen=self.F_tile.F_bk0blen,
                F_rm=self.F_tile.F_rm, F_rn=self.F_tile.F_rn, F_rk=self.F_tile.F_rk,
                F_wm=self.F_tile.F_wm, F_wn=self.F_tile.F_wn, F_wk=self.F_tile.F_wk,
                F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad=BOOL_MAP[self.F_pipeline.F_spad], F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias=BIAS_MAP[self.F_pipeline.F_bias], F_lse=BOOL_MAP[self.F_pipeline.F_lse],
                F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], F_squant=BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy=self.F_tile.F_occupancy,
                F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                F_mode=MODE_MAP[self.F_mode],
                F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
                F_tile_partitioner=TILE_PARTITIONER_MAP[self.get_tp()])

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0, bn0=self.F_tile.F_bn0, bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1, bk1=self.F_tile.F_bk1, bk0blen=self.F_tile.F_bk0blen,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias,
            lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout,
            squant=self.F_pipeline.F_squant,
            spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    if direction == 'fwd':
        if dtype == 'fp16' or dtype == 'bf16':
            return {
                '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
                '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
                '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
                '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
            }
        elif dtype == 'fp8' or dtype == 'bf8':
            return {
                '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
                '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
                '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
            }
        else:
            return None
    else:
        return None
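For example, under these tables the fp16 entry for hdim 128 resolves to the following tile name fragment (illustrative usage, not part of generate.py):

tile = get_fmha_fwd_tile_dict_from_dtype('fwd', 'fp16')['128']
print(tile.name)  # -> b128x128x32x128x32x128_r4x1x1_w32x32x16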
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
        # this function will populate a list of possible pipelines
        # TODO: the order of the List matters! entries later in this list will also be checked later
        # TODO: currently for qr pipeline, let 't' padding appear later!!
        # TODO: how to design this more generically?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                if hdim == 256:
                # if True:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                else:
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    if receipt == 1:
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))  # TODO: cover arbitrary hdim
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))  # TODO: cover arbitrary hdim
        elif dtype in ['fp8', 'bf8']:
            # no need lse/dropout kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
        else:
            assert False
        return pipelines

from codegen.cmake_config import *
from codegen.ops import (fmha_fwd, fmha_fwd_splitkv, fmha_bwd)

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
        d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
        if d == None:
            continue
        #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group":
                    if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                        # in group mode, spad/skpad must be true, since we can't predict if the seqlen of the current batch needs padding or not
                        continue
                k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode,
                                  F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl)
                if kernel_filter != None:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_bias in ['no', 'alibi']
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

class HandlerId(IntEnum):
    LIST_BLOBS  = 0
    WRITE_BLOBS = 1

    return (api_pool, gen)
BWD_DQDKDV_PIPELINE_MAP = {
    "ks_kts_vr"    : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
    "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
    "ks_vr"        : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
}

BWD_DQDKDV_PIPELINE_ENUM_MAP = {
    "ks_kts_vr"    : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
    "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
    "ks_vr"        : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
}

handlers = {
    'fwd'         : (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
    'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
    'bwd'         : (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
}
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
\n// auto generated by generate.py
#include "fmha_bwd.hpp"
"""

FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;

// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
//       G0&G2 -> GSdP
//       G1&G3 -> GdKV
//       G4    -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
                                                         fmha_block_warps0_{F_idx},
                                                         fmha_warp_tile_{F_idx},
                                                         fmha_block_warps1_{F_idx},
                                                         fmha_warp_tile_{F_idx},
                                                         fmha_block_warps0_{F_idx},
                                                         fmha_warp_tile_{F_idx},
                                                         fmha_block_warps1_{F_idx},
                                                         fmha_warp_tile_{F_idx},
                                                         fmha_block_warps2_{F_idx},
                                                         fmha_warp_tile_{F_idx}>;

using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                       {F_skpad},
                                                       {F_dpad},
                                                       {F_dvpad},
                                                       {F_bias},
                                                       {F_dbias},
                                                       false,
                                                       {F_dropout},
                                                       false,
                                                       {F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};

using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::GemmDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
    fmha_bwd_shape_{F_idx},
    {F_mode},
    fmha_mask_{F_idx},
    fmha_bwd_trait_{F_idx}>;

using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
    fmha_bwd_pipeline_problem_{F_idx}>;

using fmha_bwd_dk_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
                                                                 typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
                                                                 false, false>>;

using fmha_bwd_dv_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
                                                                 typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
                                                                 false, false>>;

using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
    ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
                                 fmha_bwd_pipeline_{F_idx},
                                 fmha_bwd_dk_epilogue_{F_idx},
                                 fmha_bwd_dv_epilogue_{F_idx}>;

using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

#include <iostream>

template<>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template<>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
    auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}

template<>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
    return k_::GetName();
}}
"""

FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp"

FMHA_BWD_API = """
#include <iostream>

template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    if(s.log_level_ > 0)
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
        [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
    );
}}

float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

FMHA_BWD_API_PER_DTYPE = """    {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
    }}
"""

FMHA_BWD_API_PER_HDIM_CASE = """        {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
        }}
"""

FMHA_BWD_API_INNER_DISPATCH = """            {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
                        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
                using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
                r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
                return r;
            }}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
    pipeline : str
    # sync with fmha_bwd_traits<>, to generate fallback calls
    hdim    : str
    dtype   : str  # data type
    mode    : str  # value from MODE_MAP
    bm0     : int  # tile size along q seqlen (block size)
    bn0     : int  # tile size along k seqlen
    bhdq    : int  # q head_dim
    bhdv    : int  # v head_dim
    mask    : str
    bias    : str
    dbias   : str
    dropout : str
    spad    : str
    skpad   : str
    dpad    : str
    dvpad   : str

    @property
    def name(self) -> str:
        return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    def scheck(self, spad1 : str) -> str:
        if self.mode == 'group':
            return 'true'  # always support
        elif self.spad == 't' and spad1 == 't':
            return f'a.seqlen_q % {self.bm0} != 0'
        elif self.spad == 'f' and spad1 == 't':
            return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0'  # BlockSize
        else:
            # self.skpad == 'f' and skpad1 == 'f'
            return f'a.seqlen_q % 256 == 0'  # BlockSize

    @property
    def skcheck(self) -> str:
        if self.mode == 'group':
            return 'true'  # always support
        elif self.skpad == 't':
            return f'a.seqlen_k % {self.bn0} != 0'
        else:
            return f'a.seqlen_k % {self.bn0} == 0'

    @property
    def dcheck(self) -> str:
        if self.dpad == 't':
            return f'a.hdim_q % {self.bhdq} != 0'
        else:
            return f'a.hdim_q % {self.bhdq} == 0'

    @property
    def dvcheck(self) -> str:
        if self.dvpad == 't':
            return f'a.hdim_v % {self.bhdv} != 0'
        else:
            return f'a.hdim_v % {self.bhdv} == 0'
class FmhaBwdApiPool:
    def __init__(self, mask_impl):
        self.dq_dk_dv_pool = dict()
        self.mask_impl = mask_impl

    def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.dq_dk_dv_pool.keys():
            self.dq_dk_dv_pool[trait.dtype] = dict()
        if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
            self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
        self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes = str()
        for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
                traits = self.dq_dk_dv_pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    for spad1 in ["t", "f"]:
                        if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
                            continue
                        inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(
                            F_if=if_k, F_mode=MODE_MAP[trait.mode],
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                            F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
                            F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                            F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
                            F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch=per_dtypes)
# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV (this was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T (this was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK (similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ (N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
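The comment block above names the five GEMMs of the FMHA backward pass. As an orientation aid only (not part of this commit; all sizes below are made up, and the transposes in the comments reflect the kernel's layout conventions rather than plain math notation), here is a NumPy shape check of the same five products:

# Illustrative shape check of the five backward GEMMs.
import numpy as np
s_q, s_k, dqk, dv = 64, 128, 32, 32   # q/k seqlens, qk and v head dims
rng = np.random.default_rng(0)
Q  = rng.standard_normal((s_q, dqk))
K  = rng.standard_normal((s_k, dqk))
V  = rng.standard_normal((s_k, dv))
dO = rng.standard_normal((s_q, dv))

St  = K @ Q.T                          # GEMM0: S^T,  shape (s_k, s_q)
P   = rng.standard_normal((s_q, s_k))  # stand-in for softmax(S)
dV  = P.T @ dO                         # GEMM1: dV,   shape (s_k, dv)  -> N1 == headdim_v
dPt = V @ dO.T                         # GEMM2: dP^T, shape (s_k, s_q)
dS  = rng.standard_normal((s_q, s_k))  # stand-in for the softmax backward of (P, dP)
dK  = dS.T @ Q                         # GEMM3: dK,   shape (s_k, dqk) -> N3 == headdim_qk
dQ  = dS @ K                           # GEMM4: dQ,   shape (s_q, dqk) -> N4 == headdim_qk
assert dV.shape == (s_k, dv) and dK.shape == (s_k, dqk) and dQ.shape == (s_q, dqk)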
@dataclass
class FmhaBwdDQDKDVTileSize:
    F_bm0  : int  # tile size along q seqlen (block size)
    F_bn0  : int  # tile size along k seqlen
    F_bk0  : int  # tile size along gemm0 unroll (F_bhdq)
    F_bk1  : int  # tile size along gemm1 unroll (F_bm0)
    F_bk2  : int  # tile size along gemm2 unroll (F_bhdv)
    F_bk3  : int  # tile size along gemm3 unroll (F_bm0)
    F_bk4  : int  # tile size along gemm4 unroll (F_bn0)
    F_bhdq : int  # q head_dim
    F_bhdv : int  # v head_dim
    F_rm0  : int  # number of warps along q seqlen (block warps) in gemm0/gemm2
    F_rn0  : int  # number of warps along k seqlen (block warps) in gemm0/gemm2
    F_rk0  : int  # number of warps along gemm-k (not used) in gemm0/gemm2
    F_rm1  : int  # number of warps along k seqlen (block warps) in gemm1/gemm3
    F_rn1  : int  # number of warps along q seqlen (block warps) in gemm1/gemm3
    F_rk1  : int  # number of warps along gemm-k (not used) in gemm1/gemm3
    F_rm2  : int  # number of warps along k seqlen (block warps) in gemm4
    F_rn2  : int  # number of warps along q seqlen (block warps) in gemm4
    F_rk2  : int  # number of warps along gemm-k (not used) in gemm4
    F_wm   : int  # warp size along m (warp size)
    F_wn   : int  # warp size along n
    F_wk   : int  # warp size along k
    F_occupancy : int  # occupancy

    @property
    def name(self) -> str:
        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" + \
               f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" + \
               f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
@dataclass
class FmhaBwdDQDKDVKernel:
    direction  : str
    F_idx      : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim     : int  # hdim
    F_dtype    : str  # data type
    F_tile     : FmhaBwdDQDKDVTileSize
    F_spad     : str  # true/false
    F_skpad    : str
    F_dpad     : str
    F_dvpad    : str
    F_bias     : str
    F_dbias    : str
    F_dropout  : str
    F_mask     : str  # value from MASK_MAP
    F_mode     : str  # value from MODE_MAP
    F_pipeline : str
    mask_impl  : str

    @property
    def template(self) -> str:
        return FMHA_BWD_KERNEL_HEADER + \
            FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
                F_idx=self.F_idx, F_hdim=self.F_hdim, F_dtype=DTYPE_MAP[self.F_dtype],
                F_bm0=self.F_tile.F_bm0, F_bn0=self.F_tile.F_bn0,
                F_bk0=self.F_tile.F_bk0, F_bk1=self.F_tile.F_bk1, F_bk2=self.F_tile.F_bk2,
                F_bk3=self.F_tile.F_bk3, F_bk4=self.F_tile.F_bk4,
                F_bhdq=self.F_tile.F_bhdq, F_bhdv=self.F_tile.F_bhdv,
                F_rm0=self.F_tile.F_rm0, F_rn0=self.F_tile.F_rn0, F_rk0=self.F_tile.F_rk0,
                F_rm1=self.F_tile.F_rm1, F_rn1=self.F_tile.F_rn1, F_rk1=self.F_tile.F_rk1,
                F_rm2=self.F_tile.F_rm2, F_rn2=self.F_tile.F_rn2, F_rk2=self.F_tile.F_rk2,
                F_wm=self.F_tile.F_wm, F_wn=self.F_tile.F_wn, F_wk=self.F_tile.F_wk,
                F_spad=BOOL_MAP[self.F_spad], F_skpad=BOOL_MAP[self.F_skpad],
                F_dpad=BOOL_MAP[self.F_dpad], F_dvpad=BOOL_MAP[self.F_dvpad],
                F_bias=BIAS_MAP[self.F_bias], F_dbias=BOOL_MAP[self.F_dbias],
                F_dropout=BOOL_MAP[self.F_dropout],
                F_occupancy=self.F_tile.F_occupancy,
                F_mask=get_mask_map(self.mask_impl)[self.F_mask],
                F_mode=MODE_MAP[self.F_mode],
                F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
                F_pipeline=BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_skpad == 't': n += 'sk'
            if self.F_dpad == 't': n += 'd'
            if self.F_dvpad == 't': n += 'dv'
            if n != '': n = 'p' + n
            return n
        pn = pad_name()
        n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
        if pn != '': n += f'_{pn}'
        if self.F_bias != 'no': n += f'_{self.F_bias}'
        if self.F_dbias == 't': n += '_dbias'
        if self.F_mask[0:2] == 's_':
            if self.F_mask == 's_mask': n += f'_mask'
        else:
            if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
        if self.F_dropout == 't': n += '_dropout'
        return n

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
        return FmhaBwdDQDKDVApiTrait(
            pipeline=self.F_pipeline,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0, bn0=self.F_tile.F_bn0,
            bhdq=self.F_tile.F_bhdq, bhdv=self.F_tile.F_bhdv,
            mask=self.F_mask, bias=self.F_bias, dbias=self.F_dbias,
            dropout=self.F_dropout,
            spad=self.F_spad, skpad=self.F_skpad, dpad=self.F_dpad, dvpad=self.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    if direction == 'bwd':
        if dtype == 'fp16' or dtype == 'bf16':
            return {
                '32'  : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), "qs_ks_vr_dos"],
                '64'  : [FmhaBwdDQDKDVTileSize(64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), "qs_ks_vr_dos"],
                '128' : [FmhaBwdDQDKDVTileSize(64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), "ks_vr"]
            }
        else:
            return None
    else:
        return None
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
    # TODO: we don't support tuning yet, so pick up one value for pad
    #       support this in future
    gen = list()
    api_pool = FmhaBwdApiPool(mask_impl)

    for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if d == None:
            continue
        for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(
                d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(),
                ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
            tile = d[hdim_str][0]
            ppl = d[hdim_str][1]
            hdim = int(hdim_str)
            if (mode == "group") and (spad == "f" or skpad == "f"):
                continue
            if ((bias == "no" or bias == "alibi") and dbias == "t"):
                continue
            k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
                                    F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
                                    F_bias=bias, F_dbias=dbias, F_dropout=dropout,
                                    F_mask=mask, F_mode=mode, F_pipeline=ppl, mask_impl=mask_impl)
            if kernel_filter != None:
                if not fnmatch.fnmatch(k.name, kernel_filter):
                    continue
            if receipt == 2:
                cond = dtype in ['fp16', 'bf16']
                cond &= bias in ['no', 'alibi']
                if not cond:
                    continue
            api_pool.register_dq_dk_dv_traits(k.api_trait())
            gen.append(k)

    return (api_pool, gen)
FMHA_BWD_DOT_DO_O_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
                                                                            {F_dvpad},
                                                                            {F_occupancy}>;

using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
    /* BlockSize = */ 256,
    {F_hdim},
    {F_mode},
    fmha_bwd_dot_do_o_trait_{F_idx}>;

using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
    fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;

using fmha_bwd_dot_do_o_kernel_{F_idx} =
    ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
                                    fmha_bwd_dot_do_o_{F_idx}>;

using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;

#include <iostream>

template<>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template<>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}

template<>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    return k_::GetName();
}}
"""
@dataclass
class FmhaBwdOGradDotOKernel:
    direction   : str
    F_idx       : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim      : int  # hdim
    F_dtype     : str  # data type
    F_spad      : str  # true/false
    F_dvpad     : str
    F_mode      : str  # value from MODE_MAP
    F_occupancy : int

    @property
    def template(self) -> str:
        return FMHA_BWD_KERNEL_HEADER + \
            FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
                F_idx=self.F_idx, F_hdim=self.F_hdim, F_dtype=DTYPE_MAP[self.F_dtype],
                F_spad=BOOL_MAP[self.F_spad], F_dvpad=BOOL_MAP[self.F_dvpad],
                F_mode=MODE_MAP[self.F_mode], F_occupancy=self.F_occupancy)

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_dvpad == 't': n += 'dv'
            if n != '': n = 'p' + n
            return n
        pn = pad_name()
        n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
        if pn != '': n += f'_{pn}'
        return n

    @property
    def filename(self) -> str:
        return self.name + ".cpp"
def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = list()

    for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if d == None:
            continue
        for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]):
            hdim = int(hdim_str)
            if (mode == "group" and spad == "f"):
                continue
            k = FmhaBwdOGradDotOKernel(direction=direction + "_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype,
                                       F_spad=spad, F_dvpad=dvpad, F_mode=mode,
                                       F_occupancy=get_occupancy(dtype, hdim))
            gen.append(k)

    return gen
def write_single_fwd_kernel(kernel : FmhaFwdKernel, autogen_dir : Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir : Path) -> None:
    (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)

def write_single_bwd_dq_dk_dv_kernel(kernel : FmhaBwdDQDKDVKernel, autogen_dir : Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_single_bwd_dot_do_o_kernel(kernel : FmhaBwdOGradDotOKernel, autogen_dir : Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir : Path) -> None:
    (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def write_blobs(output_dir : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    if output_dir is None:
        output_dir = Path(__file__).parent
    else:
        output_dir = Path(output_dir) / GEN_DIR

    output_dir.mkdir(parents=True, exist_ok=True)
    if direction == 'fwd':
        api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            write_single_fwd_kernel(kernel, output_dir)
        write_fwd_api(api_pool, output_dir)
    else:
        kernels = get_bwd_dot_do_o_blobs()
        for kernel in kernels:
            write_single_bwd_dot_do_o_kernel(kernel, output_dir)
        api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
        write_bwd_api(api_pool, output_dir)
    for api in api_list:
        handler = handlers[api][HandlerId.WRITE_BLOBS]
        handler(output_dir, kernel_filter, receipt, mask_impl)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    assert output_file is not None
    file_path = Path(output_file)
    with file_path.open('a') as f:
        if direction == 'fwd':
            _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
            for kernel in kernels:
                f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
            f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
        else:
            kernels = get_bwd_dot_do_o_blobs()
            for kernel in kernels:
                f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
            _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
            for kernel in kernels:
                f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
            f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
    for api in api_list:
        handler = handlers[api][HandlerId.LIST_BLOBS]
        handler(file_path, kernel_filter, receipt, mask_impl)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
prog
=
"generate"
,
description
=
"gen
api
for CK fmha kernel"
,
description
=
"gen
API
for CK fmha kernel"
,
)
parser
.
add_argument
(
"-d"
,
"--direction"
,
"--direction"
,
# we keep 'direction' option for backward compatibility
"-a"
,
"--api"
,
default
=
'fwd'
,
choices
=
[
'fwd'
,
'bwd'
],
required
=
False
,
help
=
"
choose the direction of kernels(default: fwd)
"
help
=
"
supply API(s) to generate (default: fwd). separated by comma.
"
)
parser
.
add_argument
(
"-o"
,
...
...
@@ -1251,7 +99,8 @@ if __name__ == "__main__":
)
args
=
parser
.
parse_args
()
api_list
=
args
.
direction
.
split
(
','
)
if
args
.
list_blobs
is
not
None
:
list_blobs
(
args
.
list_blobs
,
a
rgs
.
direction
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
list_blobs
(
args
.
list_blobs
,
a
pi_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
else
:
write_blobs
(
args
.
output_dir
,
a
rgs
.
direction
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
write_blobs
(
args
.
output_dir
,
a
pi_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
\ No newline at end of file
include/ck/ck.hpp
...
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
 #define __gfx11__
 #endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif

 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
...
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx103__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx11__)
+#elif defined(__gfx11__) || defined(__gfx12__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif
...
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8
-#elif defined(__gfx11__)
+#elif defined(__gfx11__) || defined(__gfx12__)
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11
...
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #define CK_USE_AMD_MFMA_GFX940
 #endif

-// WMMA instruction
-#ifndef __HIP_DEVICE_COMPILE__ // for host code
-#define CK_USE_AMD_WMMA
-#elif defined(__gfx11__) // for GPU code
-#define CK_USE_AMD_WMMA
-#endif

 // buffer load
 #define CK_USE_AMD_BUFFER_LOAD 1
...
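For readers scanning the guards above: the buffer-resource config dword is resolved once per compilation target, purely by the preprocessor, and this commit simply folds gfx12 into the gfx11 branch. A minimal, self-contained sketch of the same selection pattern follows; the DEMO_* macros are illustrative stand-ins for the real target defines, not CK names.

// Hedged sketch: per-target config dword selection, mirroring the guards above.
// Compile with exactly one DEMO_* macro defined to pick a branch.
#include <cstdio>

#if defined(DEMO_GFX103)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(DEMO_GFX11) || defined(DEMO_GFX12)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 // gfx11 and gfx12 share one value
#else
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#endif

int main()
{
    std::printf("config dword: 0x%08x\n", CK_BUFFER_RESOURCE_3RD_DWORD);
    return 0;
}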
include/ck/host_utility/device_prop.hpp
...
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
            ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
 }

+inline bool is_gfx12_supported()
+{
+    return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
+}
+
 } // namespace ck
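As a usage note, this helper is what the device operations later in this commit call inside their IsSupportedArgument checks. A hedged host-side sketch of that dispatch pattern; run_wmma_path and run_generic_path are hypothetical stand-ins, not CK APIs:

// Hedged sketch: runtime dispatch on the detected GPU family.
#include "ck/host_utility/device_prop.hpp"

bool run_wmma_path();    // hypothetical: kernel built on WMMA instructions
bool run_generic_path(); // hypothetical: fallback for other targets

bool launch_gemm()
{
    if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
    {
        return run_wmma_path(); // WMMA instructions are available on RDNA3/RDNA4
    }
    return run_generic_path();
}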
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...
@@ -13,6 +13,504 @@
namespace ck {

#ifdef __gfx12__
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatAcc,
          typename ABlockDesc,
          typename BBlockDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWMMA,
          index_t NPerWMMA,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool AEnableLds = true,
          bool BEnableLds = true,
          bool TransposeC = false>
/* Option: Read from LDS, big buffer hold all threads required data
 * Source
 *   A: K0PerBlock x MPerBlock x K1
 *   B: K0PerBlock x NPerBlock x K1
 * Destination
 *   C, non-transpose
 *   thread level: MRepeat x NRepeat x MAccVgprs
 *   block level:  MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 *   KPACK == WMMA_K = 16
 *
 * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
 * Source:
 *   A (if skip LDS): MRepeat x KPack
 *   B (if skip LDS): NRepeat x KPack
 * Destination
 *   C, non-transpose
 *   block level:  MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 */
struct BlockwiseGemmWMMA
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto WmmaK = Number<16>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    // Hardcode of WaveSize, since current HIP Runtime (5.4.0-10984) could not return correct one.
    static constexpr index_t WaveSize = 32;

    // When using LDS, each row (16 consecutive lanes) reads the whole data from the source buffer.
    // When not using LDS, each row reads half of the whole data from the source buffer and
    // exchanges the data via permutation.
    static constexpr index_t A_KRow = 2;
    static constexpr index_t B_KRow = 2;

    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);

    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              wmma_gemm.GetRegSizePerWmma(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    // Default: block buffer in LDS, thread-level offset enabled
    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        if constexpr(AEnableLds)
        {
            const auto wave_idx   = GetWaveIdx();
            const auto waveId_m   = wave_idx[I0];
            const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

            //                |KRepeat |MRepeat|MWave    |KRow                      |MLane      |KPack
            return make_tuple(0,       0,      waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
        }
        else
        {
            return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        if constexpr(BEnableLds)
        {
            const auto wave_idx   = GetWaveIdx();
            const auto waveId_n   = wave_idx[I1];
            const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

            //                |KRepeat |NRepeat|NWave    |KRow                      |NLane      |KPack
            return make_tuple(0,       0,      waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
        }
        else
        {
            return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();

        constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();

        return make_tuple(
            Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
    }

    using Tuple6 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
                                          Tuple6 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                          NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }

    // transposed WMMA output: C' = B' * A'
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            //         |MRepeat          |MWave|MSubGroup|NRepeat          |NWave|NThreadPerSubGroup|MAccVgprs
            make_tuple(Number<MRepeat>{}, I1,  I1,       Number<NRepeat>{}, I1,  I1,                NAccVgprs));
    }

    // Thread level, register descriptor. Vector-write
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
        constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
        return make_naive_tensor_descriptor(
            //         |MRepeat          |MWave|MSubGroup|NRepeat          |NWave|NThreadPerSubGroup|MAccVgprs
            make_tuple(Number<MRepeat>{}, I1,  I1,       Number<NRepeat>{}, I1,  I1,                MAccVgprs),
            make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       AccStride));
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
        const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(
                    make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
                    make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
    }

    // transposed WMMA output: C' = B' * A'
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NRepeat>{},
                                                           Number<NWaves>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }

    // Provide dimension size
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NRepeat>{},
                                                           Number<NWaves>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }

    // Describe how data is allocated in the thread copy src buffer
    // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
    static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
            b_thread_desc_.GetElementSpaceSize());

        static_assert(KPack % (A_K1 * A_KRow) == 0, "");
        static_assert(KPack % (B_K1 * B_KRow) == 0, "");

        // basic intrinsic to determine loop-over direction
        if constexpr(MRepeat < NRepeat)
        {
            static_for<0, KPerBlock / KPack, 1>{}(
                [&](auto k) { // k = 0, 1, 2 instead of k = 0, KPack*1, ...
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        // read A
                        a_thread_copy_.Run(
                            a_block_desc_k0_m0_m1_m2_k1,
                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(I0, m0, I0, I0, I0, I0),
                            a_thread_buf);

                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            // read B
                            b_thread_copy_.Run(
                                b_block_desc_k0_n0_n1_n2_k1,
                                make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                                b_block_buf,
                                b_thread_desc_,
                                make_tuple(I0, n0, I0, I0, I0, I0),
                                b_thread_buf);

                            vector_type<FloatA, KPack / A_KRow> a_thread_vec;
                            vector_type<FloatB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
                                a_thread_vec.template AsType<FloatA>()(i) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
                                b_thread_vec.template AsType<FloatB>()(i) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<FloatA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<FloatB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            wmma_gemm.template Run(
                                a_thread_vec.template AsType<wmma_input_type_a>(),
                                b_thread_vec.template AsType<wmma_input_type_b>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
        }
        else
        {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k = 0, 1, 2 instead of
                                                                        // k = 0, KPack*1, ...
                        // read B
                        b_thread_copy_.Run(
                            b_block_desc_k0_n0_n1_n2_k1,
                            make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                            b_block_buf,
                            b_thread_desc_,
                            make_tuple(I0, n0, I0, I0, I0, I0),
                            b_thread_buf);
                        // read A
                        a_thread_copy_.Run(
                            a_block_desc_k0_m0_m1_m2_k1,
                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(I0, m0, I0, I0, I0, I0),
                            a_thread_buf);

                        vector_type<FloatA, KPack / A_KRow> a_thread_vec;
                        vector_type<FloatB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
                            a_thread_vec.template AsType<FloatA>()(i) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
                            b_thread_vec.template AsType<FloatB>()(i) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<FloatA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<FloatB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        wmma_gemm.template Run(
                            a_thread_vec.template AsType<wmma_input_type_a>(),
                            b_thread_vec.template AsType<wmma_input_type_b>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }

    protected:
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
        make_tuple(Number<A_K1>{},
                   Number<KPack / A_KRow>{},
                   Number<A_K1>{},
                   Number<A_K1>{},
                   Number<A_K1>{},
                   Number<1>{}));

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
        make_tuple(Number<B_K1>{},
                   Number<KPack / B_KRow>{},
                   Number<B_K1>{},
                   Number<B_K1>{},
                   Number<B_K1>{},
                   Number<1>{}));

    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));

    template <bool EnableLds>
    struct AThreadCopySelector;

    template <>
    struct AThreadCopySelector<true>
    {
        using type =
            ThreadwiseTensorSliceTransfer_v4<FloatA,
                                             FloatA,
                                             decltype(a_block_desc_k0_m0_m1_m2_k1),
                                             decltype(a_thread_desc_),
                                             Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
                                             Sequence<0, 1, 2, 3, 4, 5>,
                                             5,
                                             A_K1,
                                             A_K1>;
    };

    template <>
    struct AThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
            FloatA,
            FloatA,
            decltype(a_block_desc_k0_m0_m1_m2_k1),
            decltype(a_thread_desc_),
            tensor_operation::element_wise::PassThrough,
            Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
            Sequence<0, 1, 2, 3, 4, 5>,
            5,
            A_K1,
            false>;
    };

    template <bool EnableLds>
    struct BThreadCopySelector;

    template <>
    struct BThreadCopySelector<true>
    {
        using type =
            ThreadwiseTensorSliceTransfer_v4<FloatB,
                                             FloatB,
                                             decltype(b_block_desc_k0_n0_n1_n2_k1),
                                             decltype(b_thread_desc_),
                                             Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
                                             Sequence<0, 1, 2, 3, 4, 5>,
                                             5,
                                             B_K1,
                                             B_K1>;
    };

    template <>
    struct BThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
            FloatB,
            FloatB,
            decltype(b_block_desc_k0_n0_n1_n2_k1),
            decltype(b_thread_desc_),
            tensor_operation::element_wise::PassThrough,
            Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
            Sequence<0, 1, 2, 3, 4, 5>,
            5,
            B_K1,
            false>;
    };

    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#else
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
...
@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA
    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#endif

} // namespace ck
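To make the layout comments at the top of this struct concrete, here is a hedged, host-only illustration of the wave and K-split arithmetic for one plausible instantiation. All tile numbers are invented for the example, not taken from this commit.

// Hedged illustration of the gfx12 BlockwiseGemmWMMA tiling arithmetic.
// Example parameters: 128 threads, a 128x64 block tile, 16x16 WMMA tiles,
// MRepeat = 4, NRepeat = 2, KPack = 16 -- all chosen for illustration only.
int main()
{
    constexpr int WaveSize  = 32; // hardcoded in the struct above
    constexpr int MPerBlock = 128, NPerBlock = 64;
    constexpr int MPerWMMA  = 16, NPerWMMA = 16;
    constexpr int MRepeat   = 4, NRepeat = 2;
    constexpr int KPack     = 16, A_KRow = 2;

    constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 128 / (4*16) = 2
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA); //  64 / (2*16) = 2

    // The constructor's static_assert: exactly one thread per lane of every wave.
    static_assert(MWaves * NWaves * WaveSize == 128, "BlockSize must be 128 here");

    // With A_KRow = 2, each 16-lane row of a wave holds KPack / A_KRow = 8
    // K-elements of A; the two rows together cover one WmmaK = 16 slice.
    static_assert(KPack / A_KRow == 8, "per-row K fragment size");
    return 0;
}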
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...
@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
            // sync point.
            if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
            {
+#ifdef __gfx12__
+                asm volatile("\
+                s_barrier_signal -1 \n \
+                s_barrier_wait -1 \
+                " ::);
+#else
                asm volatile("s_barrier" ::);
+#endif
                __builtin_amdgcn_sched_barrier(0);
            }
            static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
...
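Aside: gfx12 drops the monolithic s_barrier instruction in favour of a signal/wait pair, which is why the hunk above forks the inline asm. A hedged, stand-alone rendering of the same idiom follows; the demo kernel is purely illustrative, and only the asm strings come from this commit (the real code additionally pairs the barrier with __builtin_amdgcn_sched_barrier(0), as shown above).

// Hedged sketch of the arch-gated workgroup barrier used above.
#include <hip/hip_runtime.h>

__device__ inline void block_sync()
{
#ifdef __gfx12__
    // gfx12: explicit signal on the workgroup barrier (-1), then wait on it.
    asm volatile("s_barrier_signal -1\n\t"
                 "s_barrier_wait -1" ::);
#else
    asm volatile("s_barrier" ::);
#endif
}

// Illustrative kernel (launched with n <= 256 threads): every lane must pass
// the barrier before the neighbour reads below.
__global__ void demo(const float* in, float* out, int n)
{
    __shared__ float tile[256];
    const int i = threadIdx.x;
    tile[i] = in[i];
    block_sync(); // all writes to tile[] complete before any read
    out[i] = tile[i] + tile[(i + 1) % n];
}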
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;

-    static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
-    static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
+    static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
+    static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
+
+    static constexpr auto AEnableLds_auto =
+        (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
+    static constexpr auto BEnableLds_auto =
+        (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;

    // If true, LDS is used unconditionally
    static constexpr auto AEnableLds_manu = false;
...
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    static bool IsSupportedArgument(const Argument& arg)
    {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
...
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
            }
            else
            {
-                if(!(arg.a_kz_stride_ == 1 &&
-                     arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
+                if(!(arg.a_kz_stride_ == 1))
                {
                    printf("DeviceOp: Vector Access A-k check failure\n");
                    return false;
+                }
+                index_t LastK =
+                    AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
+                if(LastK % ABlockTransferSrcScalarPerVector == 0)
+                {
+                    printf("DeviceOp: Vector Access A-k check failure\n");
+                    return false;
                }
            }
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
...
@@ -70,8 +70,9 @@ __global__ void
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
    const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+    defined(__gfx12__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
    static bool IsSupportedArgument(const Argument& arg)
    {
-        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_gfx103_supported() || ck::is_gfx11_supported())
+        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
+           ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            bool pass = true;
            pass      = pass && arg.K_ % K1 == 0;
...
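The #if guard pattern above recurs in every DL kernel touched by this commit: the kernel body is compiled for the host pass (so launch syntax type-checks) and for the listed device targets, and collapses to a stub elsewhere. A hedged minimal reproduction of the idiom; the kernel name and body are illustrative, and the real kernels assign each unused parameter to `ignore` in the disabled branch.

// Hedged sketch of the host/target guard idiom used by these kernels.
#include <hip/hip_runtime.h>

__global__ void guarded_kernel(int* out)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
    defined(__gfx12__))
    out[threadIdx.x] = static_cast<int>(threadIdx.x); // real body on supported targets
#else
    (void)out; // unsupported device pass: keep the symbol, drop the body
#endif
}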
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...
@@ -56,7 +56,7 @@ __global__ void
    bool input_permute,
    bool output_permute)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    // clang-format off
// ***************************************************
...
@@ -159,6 +159,7 @@ __global__ void
    ignore = O;
    ignore = G0;
    ignore = G1;
+    ignore = alpha;
    ignore = input_permute;
    ignore = output_permute;
#endif // end of if (defined(__gfx11__))
...
@@ -187,7 +188,7 @@ __global__ void
    index_t head_size,
    float alpha)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    // clang-format off
// ***************************************************
...
@@ -321,7 +322,7 @@ __global__ void
    index_t head_size,
    float alpha)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    // clang-format off
// ***************************************************
...
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    static bool IsSupportedArgument(const RawArg& arg)
    {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
            {
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...
@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
            return false;
        }

-        if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
-           ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
-           std::is_same<ADataType, double>::value)
+        if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
        {
            return false;
        }
...
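The refactor above folds the explicit device-name list into a helper. Judging only from the removed lines, ck::is_lds_direct_load_supported() should accept gfx90a and the gfx94x family; a hedged stand-in capturing that implied equivalence (not the library's actual implementation) follows.

// Hedged stand-in for ck::is_lds_direct_load_supported(): the refactor above
// implies it matches exactly the device names the old condition enumerated.
#include <string>

bool is_lds_direct_load_supported_standin(const std::string& name)
{
    return name == "gfx90a" || name == "gfx940" || name == "gfx941" || name == "gfx942";
}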
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
    {
        // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
-             ck::is_gfx11_supported()))
+        if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
+             ck::is_gfx11_supported() || ck::is_gfx12_supported()))
        {
            return false;
        }
...
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
...
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
    static bool IsSupportedArgument(const Argument& arg)
    {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                           is_same_v<AccDataType, int32_t>))
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
...
@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
        }

-        if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
-           ck::is_gfx11_supported())
+        if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
+           ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            return GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
...
@@ -50,8 +50,9 @@ __global__ void
    const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
    const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+    defined(__gfx12__))
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
...
@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
    static bool IsSupportedArgument(const Argument& arg)
    {
-        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_gfx103_supported() || ck::is_gfx11_supported())
+        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
+           ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            return GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
...
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
    static bool IsSupportedArgument(const Argument& arg)
    {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
...
@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    // K1 = Max Vector Access Pixels
    static constexpr auto K1Number = Number<K1>{};

-    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
-    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
-    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;
-    static constexpr auto AEnableLds_auto =
-        (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
-    static constexpr auto BEnableLds_auto =
-        (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;
+
+    static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
+    static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
+
+    static constexpr auto AEnableLds_auto =
+        (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
+         is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+            ? false
+            : true;
+    static constexpr auto BEnableLds_auto =
+        (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
+         is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            ? false
+            : true;

    // If true, LDS is used unconditionally
    static constexpr auto AEnableLds_manu = false;
...
@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    static bool IsSupportedArgument(const Argument& arg)
    {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                           is_same_v<AccDataType, int32_t>))
...
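A hedged worked example of the tightened AEnableLds_auto heuristic above, with all parameter values invented: for a 2-byte element type and K1 = 8, a single load already moves 16 bytes, so MaxVectorLoadA holds and A may skip LDS whenever one wave spans N and the layout is row-major.

// Hedged sketch of the AEnableLds_auto decision with invented parameters;
// the 2-byte `half_standin` plays the role of ck::half_t.
#include <cstdio>

using half_standin = unsigned short;

int main()
{
    constexpr int K1 = 8, MRepeat = 2, NWaves = 1;
    constexpr bool a_is_row_major  = true; // pretend ALayout == RowMajor
    constexpr bool MaxVectorLoadA  = K1 * sizeof(half_standin) == 16; // 8 * 2 = 16 -> true
    constexpr bool AEnableLds_auto =
        (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && a_is_row_major) ? false : true;

    std::printf("A goes through LDS: %s\n", AEnableLds_auto ? "yes" : "no"); // prints "no"
    return 0;
}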
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
    static bool IsSupportedArgument(const Argument& arg)
    {
        // check device
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...
@@ -48,8 +48,9 @@ __global__ void
    const Block2CTileMap block_2_ctile_map,
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
-    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
+    defined(__gfx12__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
    static bool IsSupportedArgument(const Argument& arg)
    {
        // check device
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
...
@@ -90,8 +90,9 @@ __global__ void
    const Block2CTileMap block_2_ctile_map,
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
-    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
+    defined(__gfx12__))
    // offset base pointer for each work-group
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
@@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
        // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-             ck::is_gfx103_supported() || ck::is_gfx11_supported()))
+        if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
+             ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
        {
            return false;
        }
...