Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9e1dd262
Commit
9e1dd262
authored
Jul 27, 2023
by
Jing Zhang
Browse files
clean code
parent
91075f0f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
109 additions
and
29 deletions
+109
-29
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+31
-27
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
...gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+78
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
9e1dd262
...
...
@@ -193,6 +193,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
const
index_t
k_batch
=
1
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
@@ -574,15 +576,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
});
// tensor descriptors for problem definition
const
auto
a_grid_desc_m_k
=
DeviceOp
::
MakeAGridDescriptor_M_K
(
M
,
K
,
StrideA
);
const
auto
b_grid_desc_n_k
=
DeviceOp
::
MakeBGridDescriptor_N_K
(
K
,
N
,
StrideB
);
//
const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
//
const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
DsGridDesc_M_N
ds_grid_desc_m_n
;
//
DsGridDesc_M_N ds_grid_desc_m_n;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
j
.
value
,
DsLayout
>>
;
//
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
if
(
gemm_descs
[
i
].
stride_Ds_
.
size
()
!=
NumDTensor
)
{
...
...
@@ -590,9 +592,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"
);
}
StrideDs
[
j
]
=
gemm_descs
[
i
].
stride_Ds_
[
j
];
ds_grid_desc_m_n
(
j
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
M
,
N
,
gemm_descs
[
i
].
stride_Ds_
[
j
]);
StrideDs
[
j
]
=
gemm_descs
[
i
].
stride_Ds_
[
j
];
//
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
//
M, N, gemm_descs[i].stride_Ds_[j]);
});
#if 0
...
...
@@ -619,32 +621,34 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
grid_size_
+=
grid_size_grp
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
local_b2c_tile_map
))
// check block-to-E-tile
if
(
!
local_b2c_tile_map
.
CheckValidity
(
e_grid_desc_m_n
))
{
gemm_desc_kernel_arg_
.
push_back
(
GemmBiasTransKernelArg
{
p_As
.
size
()
==
0
?
nullptr
:
p_As
[
i
],
p_Bs
.
size
()
==
0
?
nullptr
:
p_Bs
[
i
],
p_ds_grid
,
p_Es
[
i
],
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
});
throw
std
::
runtime_error
(
"wrong! block_2_etile_map validation failed"
);
}
else
if
(
!
GridwiseGemm
::
template
CheckValidity
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
GemmSpec
>(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
1
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
gemm_desc_kernel_arg_
.
push_back
(
GemmBiasTransKernelArg
{
p_As
.
size
()
==
0
?
nullptr
:
p_As
[
i
],
p_Bs
.
size
()
==
0
?
nullptr
:
p_Bs
[
i
],
p_ds_grid
,
p_Es
[
i
],
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
});
group_id
++
;
}
}
...
...
@@ -682,7 +686,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
const
auto
KPad
=
GridwiseGemm
::
CalculateKPadded
(
arg
.
gemm_desc_kernel_arg_
[
i
].
K_
,
1
);
GridwiseGemm
::
CalculateKPadded
(
arg
.
gemm_desc_kernel_arg_
[
i
].
K_
,
k_batch
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
KPad
)
!=
has_main_k_block_loop
)
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
View file @
9e1dd262
...
...
@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace
ck
{
...
...
@@ -393,6 +394,71 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
e_grid_desc_m_n
);
}
/// Host/device-side validity check for a grouped-GEMM problem instance.
///
/// Rebuilds the A/B/E grid descriptors exactly as the kernel will (including
/// k-batch splitting and any GemmSpec padding) and verifies that
/// (1) the gridwise-gemm pipeline supports the resulting K-loop count and
/// (2) each tensor fits in the 2 GB addressing limit of a single buffer.
///
/// @param M, N, K          GEMM problem sizes.
/// @param StrideA/StrideB  Leading-dimension strides of A and B.
/// @param StrideDs         Strides of the D tensors (currently not validated
///                         — see the disabled code below).
/// @param StrideE          Leading-dimension stride of E.
/// @param KBatch           Split-K factor (defaults to 1, i.e. no split).
/// @return true when the configuration is usable by this gridwise gemm.
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          GemmSpecialization GemmSpec>
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
              const index_t N,
              const index_t K,
              const index_t StrideA,
              const index_t StrideB,
              const std::array<index_t, NumDTensor> StrideDs,
              const index_t StrideE,
              const index_t KBatch = 1)
{
    // Materialize the descriptors the kernel will actually use, so the size
    // checks below see the padded / k-batch-split element counts.
    const auto a_desc_kbatch_ak0_m_ak1 =
        MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
    const auto b_desc_kbatch_bk0_n_bk1 =
        MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
    const auto e_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);

    // NOTE(review): the D-tensor strides are accepted but not yet checked;
    // the intended validation is kept below for reference.
    ignore = StrideDs;

    // using DsGridDesc_M_N =
    //     remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
    // DsGridDesc_M_N ds_grid_desc_m_n;
    // static_for<0, NumDTensor, 1>{}([&](auto j) {
    //     using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
    //     ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
    // });

#if 0
    // check tile size
    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
    {
        return false;
    }
#endif

    // Check that the gridwise gemm pipeline can handle this many K loops.
    // NOTE(review): this uses the unsplit K — presumably intentional, but
    // confirm it should not be the per-k-batch loop count.
    if(!GridwiseGemmPipe::IsSupported(K / KPerBlock))
    {
        return false;
    }

    // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

    // Each tensor must stay within 2 GB (2^31 bytes) of element space.
    constexpr long_index_t kMaxTensorBytes = long_index_t{1} << 31;

    if(a_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) > kMaxTensorBytes)
    {
        return false;
    }
    if(b_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) > kMaxTensorBytes)
    {
        return false;
    }
    if(e_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) > kMaxTensorBytes)
    {
        return false;
    }

    return true;
}
#if 0
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
...
...
@@ -464,6 +530,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
return true;
}
#endif
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
...
...
@@ -616,12 +683,23 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
kbatch_id
=
0
;
//__builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// if(get_thread_local_1d_id() == 0)
//{
// printf("%d %d %d %d\n",
// get_block_1d_id(),
// kbatch_id,
// block_work_idx[I1],
// block_work_idx[I2]);
//}
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
...
...
@@ -633,8 +711,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr
auto
b_block_desc_kbatch_bk0_n_bk1
=
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1
();
const
index_t
kbatch_id
=
0
;
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment