Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bc45fd98
Commit
bc45fd98
authored
Jul 27, 2023
by
Jing Zhang
Browse files
clean code
parent
ee1c88b1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
13 additions
and
188 deletions
+13
-188
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
...e/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
+9
-3
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+2
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
...gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+0
-179
No files found.
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
View file @
bc45fd98
...
...
@@ -59,7 +59,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
struct
ProblemSize
final
...
...
@@ -202,8 +202,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
p_Cs
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
gemm_descs
.
push_back
(
{
1
,
problem_size
.
Ns
[
i
],
problem_size
.
Ks
[
i
],
1
,
problem_size
.
stride_Bs
[
i
],
1
,
{
0
}});
gemm_descs
.
push_back
({
sum_of_m
,
problem_size
.
Ns
[
i
],
problem_size
.
Ks
[
i
],
1
,
problem_size
.
stride_Bs
[
i
],
1
,
{
0
}});
grouped_gemm_kernel_args_
.
push_back
(
{
a_tensors_device
[
i
]
->
GetDeviceBuffer
(),
...
...
@@ -249,6 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
gemm
.
SetKBatch
(
argument
,
1
);
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
View file @
bc45fd98
...
...
@@ -55,10 +55,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
2
,
0
,
1
,
3
>
,
S
<
2
,
0
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
2
,
0
,
1
,
3
>
,
S
<
2
,
0
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
struct
ProblemSize
final
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
bc45fd98
...
...
@@ -13,8 +13,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
//#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
//#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -419,7 +417,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
throw
std
::
runtime_error
(
"wrong! k_batch must be > 0"
);
}
const
index_t
AverM
=
sum_of_m
/
group_count_
;
const
index_t
AverM
=
math
::
integer_divide_ceil
(
sum_of_m
,
group_count_
)
;
const
index_t
StrideE
=
gemm_desc_kernel_arg_
[
0
].
StrideE_
;
const
index_t
N
=
gemm_desc_kernel_arg_
[
0
].
N_
;
...
...
@@ -481,7 +479,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
index_t
group_id
=
0
;
sum_of_m
=
gemm_descs
[
0
].
M_
;
const
index_t
AverM
=
sum_of_m
/
group_count_
;
const
index_t
AverM
=
math
::
integer_divide_ceil
(
sum_of_m
,
group_count_
)
;
const
index_t
N
=
gemm_descs
[
0
].
N_
;
const
index_t
K
=
gemm_descs
[
0
].
K_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
View file @
bc45fd98
...
...
@@ -200,42 +200,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
c_block_size
*
sizeof
(
CShuffleDataType
));
}
#if 0
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
#endif
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
{
return
math
::
integer_least_multiple
(
M
,
MPerBlock
);
...
...
@@ -414,16 +378,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
MakeBGridDescriptor_KBatch_BK0_N_BK1
<
BLayout
,
GemmSpec
>
(
K
,
N
,
StrideB
,
KBatch
);
ignore
=
StrideDs
;
// using DsGridDesc_M_N =
// remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
// DsGridDesc_M_N ds_grid_desc_m_n;
// static_for<0, NumDTensor, 1>{}([&](auto j) {
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
// ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
//});
const
auto
e_grid_desc_m_n
=
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>
(
M
,
N
,
StrideE
);
...
...
@@ -457,80 +411,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
return
true
;
}
#if 0
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
#endif
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
...
...
@@ -540,56 +420,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
#if 0
template <typename ALayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
#endif
template
<
typename
ELayout
,
GemmSpecialization
GemmSpec
>
__host__
__device__
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
...
...
@@ -683,15 +513,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I2
]
*
NPerBlock
);
// if(get_thread_local_1d_id() == 0)
//{
// printf("%d %d %d %d\n",
// get_block_1d_id(),
// kbatch_id,
// block_work_idx[I1],
// block_work_idx[I2]);
//}
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment