Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7d42a6d4
"vscode:/vscode.git/clone" did not exist on "d38c804320192c3844ff0bc7deed83e8b8cb7856"
Commit
7d42a6d4
authored
May 11, 2022
by
ltqin
Browse files
add thread copy desc and register buffer
parent
c08dcaad
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
24 deletions
+81
-24
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
+8
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
+73
-16
No files found.
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
View file @
7d42a6d4
...
@@ -51,8 +51,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
...
@@ -51,8 +51,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
@@ -64,13 +64,13 @@ int main(int argc, char* argv[])
...
@@ -64,13 +64,13 @@ int main(int argc, char* argv[])
int
nrepeat
=
5
;
int
nrepeat
=
5
;
// GEMM shape
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
M
=
64
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
N
=
128
;
ck
::
index_t
K
=
409
6
;
ck
::
index_t
K
=
6
4
;
ck
::
index_t
StrideA
=
409
6
;
ck
::
index_t
StrideA
=
6
4
;
ck
::
index_t
StrideB
=
409
6
;
ck
::
index_t
StrideB
=
6
4
;
ck
::
index_t
StrideC
=
4096
;
ck
::
index_t
StrideC
=
128
;
if
(
argc
==
4
)
if
(
argc
==
4
)
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
View file @
7d42a6d4
...
@@ -202,6 +202,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -202,6 +202,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXDL
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
index_t
KPerThread
=
K0PerBlock
/
xdlops_gemm
.
K0PerXdlops
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -349,7 +356,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -349,7 +356,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXDL
);
const
auto
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
=
transform_tensor_descriptor
(
const
auto
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
=
transform_tensor_descriptor
(
b_grid_desc_k0_n_k1
,
b_grid_desc_k0_n_k1
,
make_tuple
(
make_pass_through_transform
(
K0
),
make_tuple
(
make_pass_through_transform
(
K0
),
...
@@ -360,6 +366,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -360,6 +366,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
,
4
>
{},
Sequence
<
5
>
{}));
return
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
;
return
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
;
}
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetWaveKNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_nk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
KPerThread
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_nk_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
...
@@ -525,26 +554,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -525,26 +554,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
// B matrix blockwise copy
constexpr
auto
b_thread_copy_desc_k0_n0_n1_n2_n3_k1
=
/*static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t KPerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
constexpr auto b_k0_n0_n1_n2_n3_k1_thread_copy_desc =
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
I1
,
// NBlockId
Number
<
MXdlPerWave
>
{},
// repeat
Number
<
MXdlPerWave
>
{},
// repeat
I1
,
// waves
I1
,
// waves
I1
,
// NPerXdlops
I1
,
// NPerXdlops
Number
<
K1
>
{}));
Number
<
K1
>
{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
ignore
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB,
FloatAB
,
b_k0_n0_n1_n2_n3_k1_thread_copy_desc.GetElementSpaceSize(),
b_thread_copy_desc_k0_n0_n1_n2_n3_k1
.
GetElementSpaceSize
(),
true>
true
>
{};
b_thread_buf;
*/
auto
b_grid_desc_k0_n0_n1_n2_n3_k1
=
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1
(
b_grid_desc_k0_n_k1
);
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1
(
b_grid_desc_k0_n_k1
);
const
auto
wave_id
=
GetWaveIdx
();
const
auto
wave_k_n_id
=
GetWaveKNIdx
(
wave_id
[
I2
]);
#if 0
const index_t block_id = get_block_1d_id();
const index_t thread_id = get_thread_local_1d_id();
printf("block id: %d m blockid: %d n block id: %d ,thread id: %d, wave id :{%d %d %d} "
"kn id: {%d %d}\n",
block_id,
block_work_idx[I0],
block_work_idx[I1],
thread_id,
wave_id[I0],
wave_id[I1],
wave_id[I2],
wave_k_n_id[I0],
wave_k_n_id[I1]);
#endif
ignore
=
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
FloatAB
,
decltype
(
b_grid_desc_k0_n0_n1_n2_n3_k1
),
decltype
(
b_thread_copy_desc_k0_n0_n1_n2_n3_k1
),
Sequence
<
Number
<
KPerThread
>
{},
I1
,
Number
<
MXdlPerWave
>
{},
I1
,
I1
,
Number
<
K1
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_grid_desc_k0_n0_n1_n2_n3_k1
,
make_multi_index
(
wave_k_n_id
[
I0
],
n_block_data_idx_on_grid
,
0
,
wave_id
[
I1
],
wave_k_n_id
[
I1
],
0
));
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment