gaoqiong / composable_kernel

Commit 1d478b9e, authored May 27, 2022 by ltqin (parent c5c32b4d)

    add tail for pipeline

Showing 4 changed files with 92 additions and 73 deletions:
    example/01_gemm/gemm_xdl_skip_lds_fp16.cpp                                   +11 -14
    include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_b_register.hpp   +17 -13
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp  +60 -46
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp   +4  -0

example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
@@ -29,8 +29,6 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
@@ -40,29 +38,28 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;

 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 #define NORMAL_CONFIG 1

 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
 //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
 //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
 //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
 //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-#if 1
+#if 0
     < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 using ADataType   = ck::half_t;
 using BDataType   = ck::half_t;
 using CDataType   = ck::half_t;
 using AccDataType = float;
 #else
-    < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>;
+    < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 1, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
 using ADataType   = float;
 using BDataType   = float;
 using CDataType   = float;
 using AccDataType = float;
 #endif
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;

 template <typename DataType>
@@ -88,10 +85,10 @@ int main(int argc, char* argv[])
     int nrepeat = 5;

     // GEMM shape
-#if NORMAL_CONFIG
+#if 0
     ck::index_t M = 256;
     ck::index_t N = 4096;
-    ck::index_t K = 64;
+    ck::index_t K = 32;
     ck::index_t StrideA = 64;
     ck::index_t StrideB = 64;
@@ -99,10 +96,10 @@ int main(int argc, char* argv[])
#else
ck
::
index_t
M
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
K
=
16
;
ck
::
index_t
K
=
8
;
ck
::
index_t
StrideA
=
16
;
ck
::
index_t
StrideB
=
16
;
ck
::
index_t
StrideA
=
8
;
ck
::
index_t
StrideB
=
8
;
ck
::
index_t
StrideC
=
16
;
#endif
...
...
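In the small debug branch the strides track the packed layouts declared at the top of the example: A is row-major M x K, so a packed StrideA equals K; B is column-major K x N, so a packed StrideB also equals K; C is row-major M x N, so StrideC stays N. That is why K, StrideA and StrideB drop from 16 to 8 together while StrideC remains 16. A quick check of that arithmetic (illustrative snippet, not part of the example):

    // Packed-stride arithmetic for the new debug shape.
    constexpr int N = 16, K = 8;
    constexpr int StrideA = K; // row-major    A (M x K): rows are K elements apart
    constexpr int StrideB = K; // column-major B (K x N): columns are K elements apart
    constexpr int StrideC = N; // row-major    C (M x N): rows are N elements apart
    static_assert(StrideA == 8 && StrideB == 8 && StrideC == 16, "matches the values in the hunk above");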
@@ -234,7 +231,7 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
#if
0
#if
1
{
show_2d_matrix
(
std
::
cout
<<
"a : "
,
a_m_k
)
<<
std
::
endl
;
show_2d_matrix
(
std
::
cout
<<
"b: "
,
b_k_n
)
<<
std
::
endl
;
...
...
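The last hunk flips the #if around the show_2d_matrix calls so the freshly shrunken 16 x 16 x 8 inputs can be inspected by eye while the new pipeline tail is being debugged. Purely as an illustration of what such a dump helper does (the repository's own show_2d_matrix takes the example's Tensor type and may differ in signature), here is a stand-alone equivalent over a plain row-major buffer:

    #include <iomanip>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for the example's show_2d_matrix: the caller streams a
    // label into the ostream first (as in the example), then the helper prints a
    // row-major rows x cols matrix and returns the stream for chaining.
    template <typename T>
    std::ostream& show_2d_matrix(std::ostream& os, const std::vector<T>& mat, int rows, int cols)
    {
        for(int r = 0; r < rows; ++r)
        {
            for(int c = 0; c < cols; ++c)
            {
                os << std::setw(6) << mat[r * cols + c];
            }
            os << "\n";
        }
        return os;
    }

    int main()
    {
        std::vector<float> a_m_k(16 * 8, 1.0f); // M x K, like the debug shape above
        show_2d_matrix(std::cout << "a : \n", a_m_k, 16, 8) << std::endl;
        return 0;
    }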

include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_b_register.hpp
@@ -12,8 +12,10 @@ template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename AK0MK1BlockDesc,
           typename BK0NK1BlockDesc,
           typename BK0K0BN0N1N2N3K1BlockDesc,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t K0PerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t MRepeat,
@@ -28,21 +30,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     static constexpr index_t WaveSize = 64;

     static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
     static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t KPerBlock =
-        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
+    static constexpr index_t KPerBlock = K0PerBlock * KPack;

     static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
     static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

     static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};

     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
-    static constexpr index_t K0PerThread = BK0NK1BlockDesc{}.GetLength(I0) / xdlops_gemm.K0PerXdlops;
+    static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
@@ -122,7 +118,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1()
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
-                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
+                          BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         static_assert(BlockSize == MWaves * NWaves * WaveSize,
@@ -235,6 +231,16 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__device__
void
MoveABlockSliceWindow
()
{
a_thread_copy_
.
MoveSrcSliceWindow
(
a_block_desc_m0_m1_m2_k
,
make_multi_index
(
0
,
0
,
0
,
K0PerBlock
*
KPack
));
}
__device__
void
ResetABlockStartWindow
()
{
a_thread_copy_
.
SetSrcCoord
(
CalculateAThreadOriginDataIndex
());
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
...
...
@@ -244,8 +250,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     {
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             a_thread_desc_.GetElementSpaceSize());

-        // auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-        //     b_thread_desc_.GetElementSpaceSize());

         static_for<0, MRepeat, 1>{}([&](auto m0) {
             // read A
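The two new member functions give the pipeline a way to advance the per-thread A read window by one K0PerBlock * KPack chunk, so the odd GEMM step reads the second half of the double-length LDS tile, and to snap the window back to the thread's origin before the tail. A minimal host-side analogy of that bookkeeping (not CK code, just the offsets it manages):

    #include <cassert>

    // Toy model of a sliding read window over the K dimension of a block tile.
    // 'origin' plays the role of CalculateAThreadOriginDataIndex().
    struct AWindow
    {
        int origin; // thread's starting K offset inside the LDS tile
        int offset; // current K offset of the window
        int step;   // K0PerBlock * KPack

        void MoveABlockSliceWindow() { offset += step; }   // hop to the second half of the tile
        void ResetABlockStartWindow() { offset = origin; } // rewind before the tail iteration
    };

    int main()
    {
        AWindow w{/*origin=*/0, /*offset=*/0, /*step=*/4 * 8}; // e.g. K0PerBlock = 4, KPack = 8
        w.MoveABlockSliceWindow();   // odd step reads the second K chunk
        assert(w.offset == 32);
        w.ResetABlockStartWindow();  // tail starts again from the window origin
        assert(w.offset == 0);
        return 0;
    }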

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
@@ -203,7 +203,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
     static constexpr auto I5 = Number<5>{};
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

+#if USING_SKIP_LDS
+    static constexpr auto MultiK0 = 2;
+#else
+    static constexpr auto MultiK0 = 1;
+#endif

     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
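With the skip-LDS path enabled, MultiK0 = 2 makes the A tile kept in LDS span two K0PerBlock chunks, so one LDS fill feeds both the even and the odd GEMM step of the main loop before the read window has to move. A rough size check under assumed tile parameters (illustrative numbers, ignoring the MPerBlock + 1 padding visible in the descriptor hunk that follows):

    // Hypothetical tile parameters for illustration only.
    constexpr int K0PerBlock = 4;
    constexpr int K1         = 8;  // KPack
    constexpr int MPerBlock  = 64;
    constexpr int MultiK0    = 2;  // skip-LDS path: two K0 chunks resident at once

    // Elements of A resident in LDS per block, before padding.
    constexpr int a_lds_elems = K0PerBlock * MultiK0 * MPerBlock * K1; // 4 * 2 * 64 * 8 = 4096
    static_assert(a_lds_elems == 4096, "double-length A tile along K0");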
@@ -223,13 +227,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -352,7 +357,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// TODO move this function into GEMM-pipeline class
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
(
K0
/
(
NumPrefetch
*
K0PerBlock
))
>
1
;
const
bool
has_main_k0_block_loop
=
(
K0
/
(
2
*
K0PerBlock
))
>
1
;
return
has_main_k0_block_loop
;
}
...
...
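Roughly speaking, the main do/while now retires 2 * K0PerBlock of the K0 dimension per trip, so the predicate asks whether more than one such double chunk exists; with fewer, the preload plus the new tail already cover everything. A quick worked check with illustrative values:

    constexpr int K0PerBlock = 4;
    constexpr bool main_loop_for_K0_32 = (32 / (2 * K0PerBlock)) > 1; // 4 > 1 -> true
    constexpr bool main_loop_for_K0_8  = (8  / (2 * K0PerBlock)) > 1; // 1 > 1 -> false
    static_assert(main_loop_for_K0_32 && !main_loop_for_K0_8, "only larger K0 enters the main loop");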
@@ -530,16 +535,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
Sequence
<
K0PerBlock
*
MultiK0
,
MPerBlock
,
K1
>
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -638,8 +640,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
FloatAB
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
),
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
...
...
@@ -653,7 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
// gridwise GEMM pipeline
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
*
MultiK0
,
0
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
// preload data to regiester and LDS
{
...
...
@@ -682,9 +686,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
index_t
i
=
0
;
do
{
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
block_sync_lds
();
// block_sync_lds();
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
...
...
@@ -693,45 +695,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m_k1
,
a_block_buf
);
// move windows
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
);
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
block_sync_lds
();
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
);
// block_sync_lds();
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m_k1
,
a_block_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
);
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
i
+=
2
;
}
while
(
i
<
(
K0BlockMainLoop
-
1
));
}
while
(
i
<
(
K0BlockMainLoop
-
2
));
}
// tail
{
block_sync_lds
();
// block_sync_lds();
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
);
blockwise_gemm
.
ResetABlockStartWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
// blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
// move windows
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
);
// block_sync_lds();
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
}
}
#else
ignore
=
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
;
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
// B matrix blockwise copy
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
...
...
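Taken together, the loop now retires two K0 chunks per trip: the even step consumes b_thread_even_buf against the first half of the A tile, the odd step slides the A window forward (MoveABlockSliceWindow) and consumes b_thread_odd_buf, and the new tail drains the final even/odd pair after the do/while exits, resetting the A window first. A compact host-side sketch of that schedule, with plain indices standing in for the buffers and copies (purely illustrative, not the kernel code):

    #include <cstdio>

    // Toy schedule of the two-stage (even/odd) main loop plus the tail added by this commit.
    // Indices stand in for K0 chunks; the real kernel overlaps these steps with A/B prefetches.
    int main()
    {
        const int K0BlockMainLoop = 8; // number of K0 chunks (illustrative, assumed even)

        int i = 0;
        do
        {
            std::printf("main loop, even step: chunk %d (b_thread_even_buf)\n", i);
            std::printf("main loop, odd  step: chunk %d (b_thread_odd_buf, A window advanced)\n", i + 1);
            i += 2;
        } while(i < (K0BlockMainLoop - 2));

        // Tail: the final even/odd pair is processed after the A window is reset to its start.
        std::printf("tail, even step: chunk %d\n", i);
        std::printf("tail, odd  step: chunk %d\n", i + 1);
        return 0;
    }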

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -1163,6 +1163,10 @@ struct ThreadwiseTensorSliceTransfer_v4
         move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
     }

+    __device__ void SetSrcCoord(const Index& src_ref_idx)
+    {
+        src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
+    }

     private:
     SrcCoord src_ref_coord_;
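The new SetSrcCoord is what ResetABlockStartWindow in the blockwise GEMM builds on: rather than undoing accumulated MoveSrcSliceWindow steps, the transfer rebuilds its reference coordinate from an absolute index. Conceptually it folds an index into an offset through the descriptor, as in this toy model (hypothetical fold_index with explicit strides; the real make_tensor_coordinate handles arbitrary descriptor transforms):

    #include <array>

    // Fold a multi-dimensional index into a linear offset using fixed strides.
    constexpr int fold_index(const std::array<int, 3>& idx, const std::array<int, 3>& strides)
    {
        return idx[0] * strides[0] + idx[1] * strides[1] + idx[2] * strides[2];
    }

    constexpr std::array<int, 3> strides{64 * 8, 8, 1}; // e.g. a [K0][M][K1] tile, M = 64, K1 = 8

    constexpr int moved_offset = fold_index({4, 0, 0}, strides); // after a K0PerBlock = 4 move
    constexpr int reset_offset = fold_index({0, 0, 0}, strides); // SetSrcCoord back to the origin

    static_assert(moved_offset == 4 * 64 * 8 && reset_offset == 0,
                  "re-seating the coordinate lands exactly back at the window origin");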