gaoqiong / composable_kernel

Commit fb6dafee authored May 28, 2022 by ltqin

    start clear gridwise gemm

parent 09b97676

Showing 3 changed files with 87 additions and 223 deletions:

- example/01_gemm/gemm_xdl_skip_lds_fp16.cpp (+67, -1)
- include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_lds.hpp (+2, -2)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp (+18, -220)
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
...
@@ -12,6 +12,7 @@
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
 #include "device_gemm_xdl_skip_lds.hpp"
+#include "device_gemm_xdl.hpp"
 #include "device_gemm_xdl_c_shuffle.hpp"
 #include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
...
@@ -38,6 +39,9 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;

 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

+#define USING_SKIP_LDS 1
+#if USING_SKIP_LDS
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
 //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
...
@@ -57,7 +61,69 @@ using BDataType = float;
 using CDataType = float;
 using AccDataType = float;
-#endif
 // clang-format on
+#else
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
+//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>;
+using ADataType = float;
+using BDataType = float;
+using CDataType = float;
+using AccDataType = float;
+#endif

 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
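For reference, here is the same DeviceGemmXdl configuration with each template argument labeled according to the column-header comment above. This is an illustrative restatement only (it relies on the F32/Row/Col/S/PassThrough/GemmDefault aliases defined earlier in this file), not part of the committed code:

// Illustrative: the non-skip-LDS instance above, one argument per line,
// labeled per the header comment. Values copied verbatim from the diff.
using DeviceGemmInstanceAnnotated = ck::tensor_operation::device::DeviceGemmXdl<
    F32,         // AData Type
    F32,         // BData Type
    F32,         // CData Type
    F32,         // AccData Type
    Row,         // ALayout
    Col,         // BLayout
    Row,         // CLayout
    PassThrough, // A Elementwise Operation
    PassThrough, // B Elementwise Operation
    PassThrough, // C Elementwise Operation
    GemmDefault, // GEMM Specialization
    64,          // Block Size
    16,          // MPer Block
    16,          // NPer Block
    4,           // K0Per Block
    4,           // K1
    16,          // MPer XDL
    16,          // NPer XDL
    1,           // MXdl Per Wave
    1,           // NXdl Per Wave
    S<4, 16, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
    S<1, 0, 2>,  // ABlockTransfer ThreadCluster ArrangeOrder
    S<1, 0, 2>,  // ABlockTransfer SrcAccessOrder
    2,           // ABlockTransfer SrcVectorDim
    4,           // ABlockTransfer SrcScalar PerVector
    4,           // ABlockTransfer DstScalar PerVector_K1
    true,        // ABlockLds AddExtraM
    S<4, 16, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
    S<1, 0, 2>,  // BBlockTransfer ThreadCluster ArrangeOrder
    S<1, 0, 2>,  // BBlockTransfer SrcAccessOrder
    2,           // BBlockTransfer SrcVectorDim
    4,           // BBlockTransfer SrcScalar PerVector
    4,           // BBlockTransfer DstScalar PerVector_K1
    true,        // BBlockLds AddExtraN
    7,           // CThreadTransfer SrcDstVectorDim
    1>;          // CThreadTransfer DstScalar PerVector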
...

include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_lds.hpp
...
@@ -330,7 +330,7 @@ struct DeviceGemmXdlSkipLds
         if(has_main_k0_block_loop)
         {
-            const auto kernel = kernel_gemm_xdlops_v2r3<
+            const auto kernel = kernel_gemm_xdlops_skip_lds_v2r3<
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 CDataType,
...
...
@@ -363,7 +363,7 @@ struct DeviceGemmXdlSkipLds
         }
         else
         {
-            const auto kernel = kernel_gemm_xdlops_v2r3<
+            const auto kernel = kernel_gemm_xdlops_skip_lds_v2r3<
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 CDataType,
...
...

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
...
@@ -10,7 +10,7 @@
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"

#define USING_SKIP_LDS 1

namespace ck {

template <typename GridwiseGemm,
...
@@ -29,7 +29,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
+        kernel_gemm_xdlops_skip_lds_v2r3(const FloatAB* __restrict__ p_a_grid,
                                 const FloatAB* __restrict__ p_b_grid,
                                 FloatC* __restrict__ p_c_grid,
...
@@ -58,99 +58,17 @@ __global__ void
             c_element_op,
             block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
     ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
     ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
-
-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename GemmDesc,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          bool HasMainK0BlockLoop,
-          index_t MaxGroupCount>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-        kernel_grouped_gemm_xdlops_v2r3(
-            const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
-            const index_t group_count,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CElementwiseOperation c_element_op)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    const index_t block_id = get_block_1d_id();
-#if 1
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
-           i < group_count)
-        {
-            auto group_id = i;
-            GridwiseGemm::template Run<HasMainK0BlockLoop>(
-                gemm_desc_[group_id].a_ptr,
-                gemm_desc_[group_id].b_ptr,
-                gemm_desc_[group_id].c_ptr,
-                p_shared,
-                gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
-                gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
-                gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                a_element_op,
-                b_element_op,
-                c_element_op,
-                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
-        }
-    });
-#else
-    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
-
-    index_t group_id = 0;
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
-                    i < group_count)
-                       ? i
-                       : group_id;
-    });
-
-    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
-
-    GridwiseGemm::template Run<HasMainK0BlockLoop>(
-        gemm_desc_ptr[group_id].a_ptr,
-        gemm_desc_ptr[group_id].b_ptr,
-        gemm_desc_ptr[group_id].c_ptr,
-        p_shared,
-        gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
-        gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
-        gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-        a_element_op,
-        b_element_op,
-        c_element_op,
-        gemm_desc_ptr[group_id].block_2_ctile_map_,
-        block_id_grp);
-#endif
-#else
-    ignore = gemm_desc_;
-    ignore = group_count;
-    ignore = a_element_op;
-    ignore = b_element_op;
-    ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
-}
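The deleted kernel_grouped_gemm_xdlops_v2r3 resolved each workgroup's GEMM group by scanning per-group block ranges inside a ck::static_for over a StaticallyIndexedArray. A minimal standalone C++ sketch of the same lookup, with a hypothetical GroupDesc standing in for GemmDesc and made-up ranges:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for the removed GemmDesc entries: each group
// owns a contiguous, half-open range [BlockStart_, BlockEnd_) of block ids.
struct GroupDesc { int BlockStart_; int BlockEnd_; };

// Same logic as the removed static_for scan: the group whose range
// contains block_id is selected.
int FindGroupId(const std::vector<GroupDesc>& descs, int group_count, int block_id)
{
    int group_id = -1;
    for(int i = 0; i < group_count; ++i)
        if(block_id >= descs[i].BlockStart_ && block_id < descs[i].BlockEnd_)
            group_id = i;
    return group_id;
}

int main()
{
    // Three groups covering blocks [0,4), [4,10), [10,12).
    std::vector<GroupDesc> descs{{0, 4}, {4, 10}, {10, 12}};
    for(int block_id : {0, 5, 11})
        std::printf("block %d -> group %d\n", block_id, FindGroupId(descs, 3, block_id));
}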
...
@@ -203,11 +121,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

#if USING_SKIP_LDS
    static constexpr auto MultiK0 = 2;
#else
    static constexpr auto MultiK0 = 1;
#endif

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};
...
@@ -272,15 +188,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

-#if USING_SKIP_LDS
-        constexpr auto b_block_space_size_aligned = 0;
-#else
-        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
-#endif
-
-        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
+        return (a_block_space_size_aligned) * sizeof(FloatAB);
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
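After this change, GetSharedMemoryNumberOfByte accounts only for the A tile: B never touches LDS on the skip-LDS path. A small standalone sketch of the arithmetic, with made-up tile sizes and a re-implementation of the rounding helper (assumed to behave like ck::math::integer_least_multiple, i.e. round up to the nearest multiple):

#include <cstdio>

// Mirrors ck::math::integer_least_multiple: smallest multiple of align >= x.
constexpr int integer_least_multiple(int x, int align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    // Hypothetical A tile: K0PerBlock = 4, MPerBlock = 16, K1 = 4 -> 256 elements.
    constexpr int a_elems       = 4 * 16 * 4;
    constexpr int max_lds_align = 4; // K1
    constexpr int a_aligned     = integer_least_multiple(a_elems, max_lds_align);
    constexpr int b_aligned     = 0; // B skips LDS entirely on this path

    // Matches the new return statement: only A occupies shared memory.
    std::printf("LDS bytes = %zu\n", (a_aligned + b_aligned) * sizeof(float)); // 1024
}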
...
@@ -357,7 +266,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
     // TODO move this function into GEMM-pipeline class
     __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
     {
-        const bool has_main_k0_block_loop = (K0 / (2 * K0PerBlock)) > 1;
+        const bool has_main_k0_block_loop = (K0 / (MultiK0 * NumPrefetch * K0PerBlock)) > 1;

         return has_main_k0_block_loop;
     }
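The new predicate folds MultiK0 and NumPrefetch into the old hard-coded factor of 2, so the main K0 loop is taken only when more than one full prefetch-width step fits in K0. A quick standalone check with illustrative values (K0PerBlock and NumPrefetch chosen for the example, not taken from the commit):

#include <cstdio>
#include <initializer_list>

int main()
{
    // Illustrative values: MultiK0 = 2 (skip-LDS path), NumPrefetch = 1, K0PerBlock = 4.
    constexpr int MultiK0 = 2, NumPrefetch = 1, K0PerBlock = 4;

    for(int K0 : {8, 16, 32})
    {
        // Same expression as the new code: integer division, then compare.
        const bool has_main_k0_block_loop = (K0 / (MultiK0 * NumPrefetch * K0PerBlock)) > 1;
        std::printf("K0=%2d -> %d\n", K0, has_main_k0_block_loop); // 8->0, 16->1, 32->1
    }
}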
...
@@ -529,9 +438,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
...
@@ -565,9 +471,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
             a_block_desc_k0_m_k1,
             make_multi_index(0, 0, 0),
             ck::tensor_operation::element_wise::PassThrough{});

-#if USING_SKIP_LDS
     ignore = b_element_op;
     ignore = max_lds_align;

     // B matrix blockwise copy
     constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
         make_naive_tensor_descriptor_packed(make_tuple(I1,
...
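On the skip-LDS path the B tile is transferred global memory -> registers directly, double-buffered per thread (the b_thread_odd_buf seen in the next hunk pairs with an even buffer) so the next tile's load can overlap the current tile's math. A toy, host-side C++ sketch of that ping-pong pattern; all names and sizes here are illustrative, not the kernel's actual types:

#include <array>
#include <cstdio>

constexpr int TileElems = 4;
using ThreadBuf = std::array<float, TileElems>;

// Stand-ins for "load B tile k from global into registers" and "run the MFMA".
void LoadTile(ThreadBuf& buf, int k) { buf.fill(static_cast<float>(k)); }
float Consume(const ThreadBuf& buf)
{
    float s = 0;
    for(float v : buf)
        s += v;
    return s;
}

int main()
{
    constexpr int num_tiles = 6;
    ThreadBuf b_thread_even_buf{}, b_thread_odd_buf{};

    LoadTile(b_thread_even_buf, 0); // prologue: prefetch the first tile
    for(int k = 0; k < num_tiles; ++k)
    {
        ThreadBuf& cur  = (k % 2 == 0) ? b_thread_even_buf : b_thread_odd_buf;
        ThreadBuf& next = (k % 2 == 0) ? b_thread_odd_buf : b_thread_even_buf;
        if(k + 1 < num_tiles)
            LoadTile(next, k + 1); // overlap next load with current compute
        std::printf("tile %d sum = %g\n", k, Consume(cur));
    }
}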
@@ -739,114 +644,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
                 blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
             }
         }
-#else
-    ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
-
-    // B matrix in LDS memory, dst of blockwise copy
-    constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-    // B matrix blockwise copy
-    auto b_blockwise_copy =
-        BlockwiseTensorSliceTransfer_v4r1<BlockSize,
-                                          BElementwiseOperation,
-                                          ck::tensor_operation::element_wise::PassThrough,
-                                          InMemoryDataOperationEnum::Set,
-                                          Sequence<K0PerBlock, NPerBlock, K1>,
-                                          BBlockTransferThreadClusterLengths_K0_N_K1,
-                                          BBlockTransferThreadClusterArrangeOrder,
-                                          FloatAB,
-                                          FloatAB,
-                                          decltype(b_grid_desc_k0_n_k1),
-                                          decltype(b_block_desc_k0_n_k1),
-                                          BBlockTransferSrcAccessOrder,
-                                          Sequence<1, 0, 2>,
-                                          BBlockTransferSrcVectorDim,
-                                          2,
-                                          BBlockTransferSrcScalarPerVector,
-                                          BBlockTransferDstScalarPerVector_K1,
-                                          1,
-                                          1,
-                                          BThreadTransferSrcResetCoordinateAfterRun,
-                                          true,
-                                          NumPrefetch>(
-            b_grid_desc_k0_n_k1,
-            make_multi_index(0, n_block_data_idx_on_grid, 0),
-            b_element_op,
-            b_block_desc_k0_n_k1,
-            make_multi_index(0, 0, 0),
-            ck::tensor_operation::element_wise::PassThrough{});
-
-    // GEMM definition
-    //   c_mtx += transpose(a_mtx) * b_mtx
-    //     a_mtx[K0PerBlock, MPerBlock] is in LDS
-    //     b_mtx[K0PerBlock, NPerBlock] is in LDS
-    //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
-    // sanity check
-    auto blockwise_gemm =
-        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                            FloatAB,
-                                                            FloatAcc,
-                                                            decltype(a_block_desc_k0_m_k1),
-                                                            decltype(b_block_desc_k0_n_k1),
-                                                            MPerXDL,
-                                                            NPerXDL,
-                                                            MXdlPerWave,
-                                                            NXdlPerWave,
-                                                            K1>{};
-
-    auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
-
-    // LDS allocation for A and B: be careful of alignment
-    constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-        a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
-
-    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-        static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
-
-    auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-        static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
-        b_block_desc_k0_n_k1.GetElementSpaceSize());
-
-    constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-    constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-
-    // gridwise GEMM pipeline
-    const auto gridwise_gemm_pipeline =
-        GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
-                                remove_cvref_t<decltype(a_blockwise_copy)>,
-                                remove_cvref_t<decltype(a_grid_buf)>,
-                                remove_cvref_t<decltype(a_block_buf)>,
-                                remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
-                                remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
-                                remove_cvref_t<decltype(b_blockwise_copy)>,
-                                remove_cvref_t<decltype(b_grid_buf)>,
-                                remove_cvref_t<decltype(b_block_buf)>,
-                                remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                remove_cvref_t<decltype(blockwise_gemm)>,
-                                remove_cvref_t<decltype(c_thread_buf)>,
-                                NumPrefetch,
-                                HasMainK0BlockLoop>{};
-
-    const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-
-    gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
-                               a_block_desc_k0_m_k1,
-                               a_blockwise_copy,
-                               a_grid_buf,
-                               a_block_buf,
-                               a_block_slice_copy_step,
-                               b_grid_desc_k0_n_k1,
-                               b_block_desc_k0_n_k1,
-                               b_blockwise_copy,
-                               b_grid_buf,
-                               b_block_buf,
-                               b_block_slice_copy_step,
-                               blockwise_gemm,
-                               c_thread_buf,
-                               K0BlockMainLoop);
-#endif

     // output: register to global memory
     {
         constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
...