gaoqiong / composable_kernel · Commits

Commit a3b4c5cb, authored Jun 03, 2022 by wangshaojie6

    merge develop branch and add gridwise pipeline v3

Parents: 48918ab9, 1677cf70
Changes: 361 files in the full merge commit
Showing 20 changed files with 2698 additions and 1760 deletions (+2698 -1760).
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp            +152  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp                      +180  -201
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp                  +168  -129
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp                  +128  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp       +248  -287
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp              +119  -171
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp            +976  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp                  +106  -251
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp                  +62   -96
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                +65   -109
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp                  +98   -159
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp                  +100  -157
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp                  +110  -170
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp                +5    -8
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp         +9    -10
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp    +1    -3
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp    +7    -2
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                +33   -3
include/ck/utility/amd_buffer_addressing.hpp                                        +108  -0
include/ck/utility/amd_xdlops.hpp                                                   +23   -4
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp (new file, mode 100644)

#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseBinEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                             const BDataType* __restrict__ p_b_global,
                                             CDataType* __restrict__ p_c_global,
                                             const AGridDesc_M a_grid_desc_m,
                                             const BGridDesc_M b_grid_desc_m,
                                             const CGridDesc_M c_grid_desc_m,
                                             const ElementwiseFunctor functor)
{
    GridwiseBinEltwise::Run(
        p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename ElementwiseFunctor,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const AGridDesc_M a_grid_desc_m,
                               const BGridDesc_M b_grid_desc_m,
                               const CGridDesc_M c_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             AGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             AScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             BGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             BScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{b_grid_desc_m, thread_store_global_offset};

        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               CDataType,
                                               decltype(thread_desc_m),
                                               CGridDesc_M,
                                               PassThrough,
                                               Sequence<MPerThread>, // SliceLengths
                                               Sequence<0>,          // DimAccessOrder
                                               0,                    // DstVectorDim
                                               CScalarPerVector,     // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1,                    // DstScalarStrideInVector
                                               false>{
                c_grid_desc_m, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = c_grid_desc_m.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = M / (loop_step);
        do
        {
            // read and process MPerThread elements
            a_global_load.Run(
                a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);

            b_global_load.Run(
                b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);

            static_for<0, MPerThread, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });

            c_global_write.Run(thread_desc_m,
                               make_tuple(I0), // SrcSliceOriginIdx
                               c_thread_buf,
                               c_grid_desc_m,
                               c_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
            c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
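For orientation, here is a hypothetical host-side sketch of how the new kernel might be instantiated and launched; it is not part of the commit. The Add functor, data types, launch configuration, and device pointers are assumptions for the illustration, and M is chosen so that it is divisible by gridSize * blockSize * MPerThread, which the do/while loop in Run relies on.

// Hypothetical launch sketch for kernel_binary_elementwise_1d (C = A + B over M floats).
#include "gridwise_binary_elementwise_1d.hpp"
#include "element_wise_operation.hpp"

using ADataType = float;
using BDataType = float;
using CDataType = float;
using Add       = ck::tensor_operation::element_wise::Add; // assumed functor

constexpr ck::index_t MPerThread = 8;
const ck::index_t M              = 1024 * 1024;

const auto a_grid_desc_m = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M));
const auto b_grid_desc_m = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M));
const auto c_grid_desc_m = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M));

using GridwiseBinEltwise =
    ck::GridwiseBinaryElementwise_1D<ADataType, BDataType, CDataType,
                                     float, // ComputeDataType
                                     decltype(a_grid_desc_m), decltype(b_grid_desc_m),
                                     decltype(c_grid_desc_m), Add,
                                     MPerThread,
                                     8, 8, 8>; // A/B/C ScalarPerVector

const auto kernel = ck::kernel_binary_elementwise_1d<GridwiseBinEltwise,
                                                     ADataType, BDataType, CDataType,
                                                     decltype(a_grid_desc_m),
                                                     decltype(b_grid_desc_m),
                                                     decltype(c_grid_desc_m), Add>;

float *p_a_dev, *p_b_dev, *p_c_dev; // assumed: device buffers of M floats allocated elsewhere
const ck::index_t block_size = 256;
const ck::index_t grid_size  = M / (block_size * MPerThread); // keeps num_iter >= 1
kernel<<<grid_size, block_size, 0, nullptr>>>(
    p_a_dev, p_b_dev, p_c_dev, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, Add{});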
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp → include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp (renamed)

-#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
-#define CK_GRIDWISE_GEMM_V1R3_HPP
+#pragma once

 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
-#include "blockwise_gemm_dlops_v2r3.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+#include "blockwise_gemm_dl_v2r3.hpp"
 #include "blockwise_tensor_slice_transfer_v5r1.hpp"
-#include "threadwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 #include "threadwise_tensor_slice_set.hpp"
 #include "element_wise_operation.hpp"

 namespace ck {

 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AK0M0M1K1GridDesc,
-          typename BK0N0N1K1GridDesc,
-          typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockIdToM0N0BlockClusterAdaptor,
+          typename AGridDesc_K0_M0_M1_K1,
+          typename BGridDesc_K0_N0_N1_K1,
+          typename CGridDesc_M0_M10_M11_N0_N10_N11,
+          typename Block2CTileMap,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
+kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
                        const FloatAB* __restrict__ p_b_grid,
                        FloatC* __restrict__ p_c_grid,
-                       const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
-                       const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
-                       const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
-                       const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
+                       const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+                       const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+                       const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
+                       const Block2CTileMap block_2_ctile_map)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...
@@ -43,10 +43,10 @@ __global__ void
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
-                      a_k0_m0_m1_k1_grid_desc,
-                      b_k0_n0_n1_k1_grid_desc,
-                      c_m0_m10_m11_n0_n10_n11_grid_desc,
-                      cblockid_to_m0_n0_block_cluster_adaptor,
+                      a_grid_desc_k0_m0_m1_k1,
+                      b_grid_desc_k0_n0_n1_k1,
+                      c_grid_desc_m0_m10_m11_n0_n10_n11,
+                      block_2_ctile_map,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
...
@@ -56,12 +56,12 @@ template <index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
-          typename CMNGridDesc,
-          index_t MPerBlockM1,
-          index_t NPerBlockN1,
-          index_t KPerBlock,
+          typename AGridDesc_K0_M_K1,
+          typename BGridDesc_K0_N_K1,
+          typename CGridDesc_M_N,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t K0PerBlock,
           index_t M1PerThreadM111,
           index_t N1PerThreadN111,
           index_t KPerThread,
...
@@ -83,13 +83,8 @@ template <index_t BlockSize,
           typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
-          index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks>
-struct GridwiseGemmDlops_km_kn_mn_v1r3
+          index_t CThreadTransferDstScalarPerVector>
+struct GridwiseGemmDl_km_kn_mn_v1r3
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
     static constexpr auto I3 = Number<3>{};

     // K1 should be Number<...>
-    static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
+    static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2);

     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
...
@@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         // TODO: check alignment
         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
+        constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

         // TODO: check alignment
         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
+        constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

         // TODO: check alignment
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);

         constexpr auto b_block_aligned_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);

         return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
     }

     __host__ __device__ static constexpr bool
-    CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                  const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                  const CMNGridDesc& c_m_n_grid_desc)
+    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                  const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
-        const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-        return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
-                K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
-                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
-                K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
-               (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
+        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
+                K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
+                K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
+                K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
+               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
     {
-        const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
+        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

         return grid_size;
     }

     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
     {
-        const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
+        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

         return has_main_k_block_loop;
     }

     __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
     {
-        const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
+        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

         return has_double_tail_k_block_loop;
     }

     __host__ __device__ static constexpr auto
-    MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
+    MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
     {
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-        const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);

-        const auto M1 = Number<MPerBlockM1>{};
+        const auto M1 = Number<MPerBlock>{};
         const auto M0 = M / M1;

-        const auto a_k0_m0_m1_k1_grid_desc = transform_tensor_descriptor(
-            a_k0_m_k1_grid_desc,
+        const auto a_grid_desc_k0_m0_m1_k1 = transform_tensor_descriptor(
+            a_grid_desc_k0_m_k1,
             make_tuple(make_pass_through_transform(K0),
                        make_unmerge_transform(make_tuple(M0, M1)),
                        make_pass_through_transform(K1)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

-        return a_k0_m0_m1_k1_grid_desc;
+        return a_grid_desc_k0_m0_m1_k1;
     }

     __host__ __device__ static constexpr auto
-    MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
+    MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
     {
-        const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
-        const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
+        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
+        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

-        const auto N1 = Number<NPerBlockN1>{};
+        const auto N1 = Number<NPerBlock>{};
         const auto N0 = N / N1;

-        const auto b_k0_n0_n1_k1_grid_desc = transform_tensor_descriptor(
-            b_k0_n_k1_grid_desc,
+        const auto b_grid_desc_k0_n0_n1_k1 = transform_tensor_descriptor(
+            b_grid_desc_k0_n_k1,
             make_tuple(make_pass_through_transform(K0),
                        make_unmerge_transform(make_tuple(N0, N1)),
                        make_pass_through_transform(K1)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

-        return b_k0_n0_n1_k1_grid_desc;
+        return b_grid_desc_k0_n0_n1_k1;
     }

     __host__ __device__ static constexpr auto
-    MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        const auto M = c_m_n_grid_desc.GetLength(I0);
-        const auto N = c_m_n_grid_desc.GetLength(I1);
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);

-        constexpr auto M1 = Number<MPerBlockM1>{};
-        constexpr auto N1 = Number<NPerBlockN1>{};
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};

         const auto M0 = M / M1;
         const auto N0 = N / N1;
...
@@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         constexpr auto M10 = M1 / M11;
         constexpr auto N10 = N1 / N11;

-        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
-            c_m_n_grid_desc,
+        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
+            c_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                        make_unmerge_transform(make_tuple(N0, N10, N11))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

-        return c_m0_m10_m11_n0_n10_n11_grid_desc;
+        return c_grid_desc_m0_m10_m11_n0_n10_n11;
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
-    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        const auto M = c_m_n_grid_desc.GetLength(I0);
-        const auto N = c_m_n_grid_desc.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlockM1>{};
-        constexpr auto N1 = Number<NPerBlockN1>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M0, N0))),
-            make_tuple(Sequence<0, 1>{}),
-            make_tuple(Sequence<0>{}));
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
     }

-    using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
-    using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
-    using CM0M10M11N0N10N11GridDesc =
-        decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
-    using CBlockIdToM0N0BlockClusterAdaptor =
-        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
+    using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
+    using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
+    using CGridDesc_M0_M10_M11_N0_N10_N11 =
+        decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
+    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void
...
@@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                     const FloatAB* __restrict__ p_b_grid,
                     FloatC* __restrict__ p_c_grid,
                     FloatAB* __restrict__ p_shared_block,
-                    const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
-                    const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
-                    const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
-                    const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
+                    const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
+                    const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
+                    const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
+                    const Block2CTileMap& block_2_ctile_map,
                     integral_constant<bool, HasMainKBlockLoop>,
                     integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
+            p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
+            p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
+            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

         // divide block work by [M, N]
         const auto c_m0_n0_block_cluster_idx =
-            cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
-                make_multi_index(get_block_1d_id()));
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

         // HACK: this force index data into SGPR
         const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
         const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

+        if(!block_2_ctile_map.ValidCTileIndex(
+               make_tuple(im0, in0),
+               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
+                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
+        {
+            return;
+        }
+
         // TODO: change this. I think it needs multi-dimensional alignment
         constexpr auto max_lds_align = K1;

         // TODO: check alignment
         // A matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
+        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

         // TODO: check alignment
         // B matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
+        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);

         // TODO: check alignment
         // A matrix in LDS memory, for blockwise GEMM
         constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
+            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

         // TODO: check alignment
         // B matrix in LDS memory, for blockwise GEMM
         constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
+            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

-        static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
+        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                           a_k0_m_k1_block_desc.GetElementSpaceSize() &&
-                      b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
+                      b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                           b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                       "wrong!");
...
@@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v5r1<BlockSize,
                                               InMemoryDataOperationEnum::Set,
-                                              Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
+                                              Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
                                               ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                               ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                               ABlockTransferThreadClusterArrangeOrder,
                                               FloatAB,
                                               FloatAB,
-                                              decltype(a_k0_m0_m1_k1_grid_desc),
-                                              decltype(a_k0_m0_m1_k1_block_desc),
+                                              remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
+                                              decltype(a_block_desc_k0_m0_m1_k1),
                                               ABlockTransferSrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
...
@@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                                               ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
                                               Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
                                               false,
-                                              true>(a_k0_m0_m1_k1_grid_desc,
+                                              true>(a_grid_desc_k0_m0_m1_k1,
                                                     make_multi_index(0, im0, 0, 0),
-                                                    a_k0_m0_m1_k1_block_desc,
+                                                    a_block_desc_k0_m0_m1_k1,
                                                     make_multi_index(0, 0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v5r1<BlockSize,
                                               InMemoryDataOperationEnum::Set,
-                                              Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
+                                              Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
                                               BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                               BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                               BBlockTransferThreadClusterArrangeOrder,
                                               FloatAB,
                                               FloatAB,
-                                              decltype(b_k0_n0_n1_k1_grid_desc),
-                                              decltype(b_k0_n0_n1_k1_block_desc),
+                                              remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
+                                              decltype(b_block_desc_k0_n0_n1_k1),
                                               BBlockTransferSrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
...
@@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                                               BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
                                               Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
                                               false,
-                                              true>(b_k0_n0_n1_k1_grid_desc,
+                                              true>(b_grid_desc_k0_n0_n1_k1,
                                                     make_multi_index(0, in0, 0, 0),
-                                                    b_k0_n0_n1_k1_block_desc,
+                                                    b_block_desc_k0_n0_n1_k1,
                                                     make_multi_index(0, 0, 0, 0));

         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
-        //     a_mtx[KPerBlock, MPerBlockM1] is in LDS
-        //     b_mtx[KPerBlocl, NPerBlockN1] is in LDS
-        //     c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
+        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
+        //     b_mtx[KPerBlocl, NPerBlock] is in LDS
+        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //     register
         const auto blockwise_gemm =
-            BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
+            BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
                 BlockSize,
                 FloatAB,
                 FloatAB,
...
@@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

-        constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
+        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
             sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
-            a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

         constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
-            b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

         FloatAB* p_a_block_double = p_shared_block;
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
-            c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
+            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

-        ThreadwiseTensorSliceSet_v1<FloatAcc,
-                                    decltype(c_m10_m11_n10_n11_thread_desc),
-                                    decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
-            .Run(c_m10_m11_n10_n11_thread_desc,
-                 make_tuple(I0, I0, I0, I0),
-                 c_thread_buf,
-                 FloatAcc{0});
+        // Initialize C
+        c_thread_buf.Clear();

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);

         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
+            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
         auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
+            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

         auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double + a_block_aligned_space_size,
-            a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
+            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
         auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_b_block_double + b_block_aligned_space_size,
-            b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
+            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

-            a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
+            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
         {
-            const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
+            const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);

             index_t k_block_data_begin = 0;
...
@@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    AGridMoveSliceWindowStepHacks{});
-                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    BGridMoveSliceWindowStepHacks{});
-
-                __syncthreads();
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
+                                                    a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                    b_block_slice_copy_step);

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-                b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+                block_sync_lds();

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
+                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                    a_block_even_buf,
                                    b_block_even_buf,
                                    c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
+                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    AGridMoveSliceWindowStepHacks{});
-                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    BGridMoveSliceWindowStepHacks{});
-
-                __syncthreads();
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                    b_block_slice_copy_step);

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-                b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+                block_sync_lds();

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
+                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                    a_block_odd_buf,
                                    b_block_odd_buf,
                                    c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
+                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

-                k_block_data_begin += 2 * KPerBlock;
-            } while(k_block_data_begin < K0 - 2 * KPerBlock);
+                k_block_data_begin += 2 * K0PerBlock;
+            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
         }

         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowStepHacks{});
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowStepHacks{});
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
+            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

-            __syncthreads();
+            block_sync_lds();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
+            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                a_block_even_buf,
                                b_block_even_buf,
                                c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
-            b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
+            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

-            __syncthreads();
+            block_sync_lds();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
+            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                a_block_odd_buf,
                                b_block_odd_buf,
                                c_thread_buf);
         }
         else // if has 1 iteration left
         {
...
@@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
+            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                a_block_even_buf,
                                b_block_even_buf,
                                c_thread_buf);
         }

         // output: register to global memory
         {
-            constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
+            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
                 make_naive_tensor_descriptor_packed(
                     make_tuple(I1,
                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
...
@@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
-                decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
-                decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
+                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
+                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
+                ck::tensor_operation::element_wise::PassThrough,
                 Sequence<1,
                          c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                          c_m10_m11_n10_n11_thread_tensor_lengths[I1],
...
@@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                 CThreadTransferDstScalarPerVector,
                 CGlobalMemoryDataOperation,
                 1,
-                true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
+                true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                       make_multi_index(im0,
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                        in0,
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
-                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
-                .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
+                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
+                      ck::tensor_operation::element_wise::PassThrough{}}
+                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                      make_tuple(I0, I0, I0, I0, I0, I0),
                      c_thread_buf,
-                     c_m0_m10_m11_n0_n10_n11_grid_desc,
-                     c_grid_buf,
-                     CGridStepHacks{});
+                     c_grid_desc_m0_m10_m11_n0_n10_n11,
+                     c_grid_buf);
         }
     }
 };

 } // namespace ck
-#endif
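This file now delegates block-to-tile assignment to BlockToCTileMap_M00_N00_M01_N01 and guards the kernel with ValidCTileIndex. Only three member functions of that type are visible in this diff; the following deliberately simplified, hypothetical type (not the CK implementation) illustrates the interface shape a block-to-C-tile map has to expose for the kernel above: map a 1D block id to an (m0, n0) tile index, validate that index, and validate the C descriptor on the host.

// Illustrative sketch only; names and the trivial row-major mapping are assumptions.
#include "common_header.hpp" // assumed to provide ck::index_t, Number, make_multi_index

template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct ToyBlock2CTileMap
{
    ck::index_t m_tiles_; // number of C tiles along M
    ck::index_t n_tiles_; // number of C tiles along N

    __host__ __device__ ToyBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
        : m_tiles_{c_grid_desc_m_n.GetLength(ck::Number<0>{}) / MPerBlock},
          n_tiles_{c_grid_desc_m_n.GetLength(ck::Number<1>{}) / NPerBlock}
    {
    }

    // 1D block id -> (m0, n0), row-major over the C tile grid
    template <typename TopIdx>
    __host__ __device__ auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        const ck::index_t block_id = idx_top[ck::Number<0>{}];
        return ck::make_multi_index(block_id / n_tiles_, block_id % n_tiles_);
    }

    // this trivial map never produces an out-of-range tile, so always accept;
    // CK's M00_N00_M01_N01 map can over-provision blocks, which is why the kernel checks
    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx&, const CTileDim&) const
    {
        return true;
    }

    // usable whenever the C extents divide evenly into tiles
    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return c_grid_desc_m_n.GetLength(ck::Number<0>{}) % MPerBlock == 0 &&
               c_grid_desc_m_n.GetLength(ck::Number<1>{}) % NPerBlock == 0;
    }
};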
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

-#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
-#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
+#pragma once

 #include "common_header.hpp"
+#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"

 namespace ck {

-template <typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          index_t NumPrefetch,
-          bool HasMainLoop>
+template <index_t NumPrefetch>
 struct GridwiseGemmPipeline_v1;

 // 1-stage prefetch
-template <typename AGridDesc,
-          ...
+template <>
+struct GridwiseGemmPipeline_v1<1>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
...
@@ -37,29 +35,8 @@ template <typename AGridDesc,
               typename BBlockBuffer,
               typename BBlockTransferStep,
               typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               1,
-                               HasMainLoop>
-{
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-
-    static __device__ void Run(const AGridDesc& a_grid_desc,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
                                const AGridBuffer& a_grid_buf,
...
@@ -75,51 +52,6 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
                                CThreadBuffer& c_thread_buf,
                                index_t num_loop)
     {
-#if 0
-        // preload data into LDS
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-        // Initialize C
-        c_thread_buf.Clear();
-
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-        // main body
-        if constexpr(HasMainLoop)
-        {
-            index_t i = 0;
-
-            do
-            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-
-                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-
-                block_sync_lds();
-
-                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-                block_sync_lds();
-
-                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-                ++i;
-            } while(i < (num_loop - 1));
-        }
-
-        // tail
-        {
-            block_sync_lds();
-
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-        }
-#else
         // preload data into LDS
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
...
@@ -166,12 +98,29 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
         }
-#endif
     }
 };

 // 2-stage prefetch
-template <typename AGridDesc,
-          ...
+template <>
+struct GridwiseGemmPipeline_v1<2>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    {
+        // TODO: improve applicability
+        return num_loop % 2 == 0;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return (num_loop / 2) > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
...
@@ -184,28 +133,7 @@ template <typename AGridDesc,
               typename BBlockBuffer,
               typename BBlockTransferStep,
               typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               2,
-                               HasMainLoop>
-{
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-
+              typename CThreadBuffer>
     static __device__ void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
...
@@ -321,5 +249,116 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
     }
 };

+template <index_t NumPrefetch>
+struct GridwiseGemmPipelineInterwave_v1;
+
+template <>
+struct GridwiseGemmPipelineInterwave_v1<1>
+{
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    static __device__ void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        // preload data into LDS
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+
+                block_sync_lds();
+
+                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+
+                // block_sync_lds(); // moved into blockwise_gemm
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+        }
+    }
+};
+
+// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
+template <>
+struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
+{
+};
+
+template <index_t NumPrefetch, LoopScheduler LoopSched>
+constexpr auto GridwiseGemmPipeline_v1_Selector()
+{
+    if constexpr(LoopSched == LoopScheduler::Default)
+    {
+        return GridwiseGemmPipeline_v1<NumPrefetch>{};
+    }
+    else if constexpr(LoopSched == LoopScheduler::Interwave)
+    {
+        return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
+    }
+}

 } // namespace ck
-#endif
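The pipeline is now selected by prefetch depth alone, and GridwiseGemmPipeline_v1_Selector picks the default or inter-wave flavor from the loop scheduler at compile time. A hypothetical usage sketch (not part of this commit; K and KPerBlock below are made-up values) of how a gridwise GEMM might consume the selector:

// Sketch only; assumes the CK headers above are on the include path.
#include "gridwise_gemm_pipeline_v1.hpp"

constexpr ck::index_t NumGemmKPrefetchStage = 1;
constexpr ck::LoopScheduler LoopSched       = ck::LoopScheduler::Default;

using GridwiseGemmPipe = ck::remove_cvref_t<decltype(
    ck::GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>())>;

constexpr ck::index_t K          = 4096; // assumed problem size
constexpr ck::index_t KPerBlock  = 32;   // assumed tile size
constexpr ck::index_t num_k_loop = K / KPerBlock;

// Host side: decide which kernel variant (with or without the main loop) to launch.
constexpr bool has_main_k_block_loop = GridwiseGemmPipe::CalculateHasMainLoop(num_k_loop);

// Device side, inside the GEMM kernel body, the chosen pipeline drives the K loop:
// GridwiseGemmPipe::template Run<has_main_k_block_loop>(
//     a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_copy_step,
//     b_grid_desc, b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_copy_step,
//     blockwise_gemm, c_thread_buf, num_k_loop);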
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp (new file, mode 100644)

#pragma once

#include "common_header.hpp"

namespace ck {

struct GridwiseGemmPipeline_v3
{
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop % 2 == 0;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return (num_loop / 2) > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // global read 0
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        // global Read 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

        // LDS write 0
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        block_sync_lds();

        // global Read 1
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                // block_sync_lds();

                // GEMM i
                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                // move to i + 2
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // LDS write i + 1
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                // global read i + 2
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

                // LDS write i + 1
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                block_sync_lds();

                // global read i + 2
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                ++i;
            } while(i < (num_loop - 2));
        }

        // tail
        {
            block_sync_lds();

            // GEMM num_loop - 2
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            // LDS write num_loop - 1
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            block_sync_lds();

            // GEMM num_loop - 1
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};

} // namespace ck
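The "gridwise pipeline v3" named in the commit message is this new scheduler. Following the comments in Run, the loop body for iteration i runs the GEMM on the tile already resident in LDS, then writes tile i + 1 into LDS and issues the global reads for tile i + 2, so two prefetched tiles are always in flight and IsSupported() restricts the K-loop count to even numbers. A small host-side sanity-check sketch, with made-up problem sizes, might look like this:

// Sketch only; K and KPerBlock are assumed values, not taken from the commit.
#include "gridwise_gemm_pipeline_v3.hpp"

constexpr ck::index_t K         = 2048;
constexpr ck::index_t KPerBlock = 32;
constexpr ck::index_t num_loop  = K / KPerBlock; // 64 K-iterations

static_assert(ck::GridwiseGemmPipeline_v3::IsSupported(num_loop),
              "the v3 pipeline requires an even number of K iterations");
constexpr bool has_main_loop = ck::GridwiseGemmPipeline_v3::CalculateHasMainLoop(num_loop);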
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
a3b4c5cb
...
...
@@ -3,29 +3,31 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "
blockwise
_tensor_slice_transfer_v4r1.hpp"
#include "
blockwise
_tensor_slice_transfer_v6r1.hpp"
#include "
thread_group
_tensor_slice_transfer_v4r1.hpp"
#include "
thread_group
_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatD
,
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D
0Reduc
eOperation
,
typename
D
1Reduc
eOperation
,
typename
D
xsInElementwis
eOperation
,
typename
D
xsAccElementwis
eOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DGridDescriptor_MBlock_MPerBlock
,
typename
Block2CTileMap
,
bool
HasMainK
0
BlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -34,13 +36,12 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatD
*
__restrict__
p_d0_grid
,
FloatD
*
__restrict__
p_d1_grid
,
DPtrsGlobal
p_ds_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
D
0ReduceOperation
d0_reduce
_op
,
const
D
1ReduceOperation
d1_reduce
_op
,
const
D
xsInElementwiseOperation
dxs_in_element
_op
,
const
D
xsAccElementwiseOperation
dxs_out_element
_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -51,17 +52,16 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainK
0
BlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_d0_grid
,
p_d1_grid
,
p_ds_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce
_op
,
d1_reduce
_op
,
dxs_in_element
_op
,
dxs_out_element
_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -71,13 +71,12 @@ __global__ void
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_d0_grid
;
ignore
=
p_d1_grid
;
ignore
=
p_ds_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
d
0_reduce
_op
;
ignore
=
d
1_reduce
_op
;
ignore
=
d
xs_in_element
_op
;
ignore
=
d
xs_out_element
_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
@@ -91,14 +90,15 @@ template <typename FloatAB,
           typename FloatCShuffle,
           typename FloatC,
           typename FloatReduceAcc,
-          typename FloatD,
+          typename DPtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename D0ReduceOperation,
-          typename D1ReduceOperation,
+          typename DxsReduceOperation,
+          typename DxsInElementwiseOperation,
+          typename DxsAccElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          InMemoryDataOperationEnum DGlobalMemoryDataOperation,
+          typename DGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
@@ -136,7 +136,8 @@ template <typename FloatAB,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
+          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+          LoopScheduler LoopSched>
 struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
@@ -154,6 +155,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
+
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -214,10 +219,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                  const CGridDesc_M_N& c_grid_desc_m_n)
+                  const CGridDesc_M_N& c_grid_desc_m_n,
+                  const Block2CTileMap& block_2_ctile_map)
     {
         // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
         //     is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
@@ -237,21 +244,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+        {
+            return false;
+        }
+
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
         {
             return false;
         }
@@ -260,23 +261,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return true;
     }

-    __host__ __device__ static constexpr index_t
-    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-
-        return grid_size;
-    }
-
-    // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop =
-            ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
+        const index_t num_loop = K / KPerBlock;

-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

     __host__ __device__ static constexpr auto
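The hunk above replaces the hand-rolled `has_main_k0_block_loop` predicate with `GridwiseGemmPipe::CalculateHasMainLoop(K / KPerBlock)`. A minimal host-side sketch of that arithmetic (illustration only; `KPerBlock = 32` is an assumed tuning value, and the `> 1` criterion mirrors a single-stage-prefetch pipeline, not every pipeline variant):

constexpr int kKPerBlockExample = 32; // assumed tuning value, for illustration only

constexpr bool has_main_k_block_loop_example(int K)
{
    const int num_loop = K / kKPerBlockExample; // number of K-block iterations
    return num_loop > 1;                        // first iteration is peeled off as prefetch
}

static_assert(has_main_k_block_loop_example(256), "K = 256 -> 8 iterations, main loop exists");
static_assert(!has_main_k_block_loop_example(32), "K = 32  -> single iteration, no main loop");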
@@ -317,40 +306,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        // FIXME: remove
-        constexpr auto M01 = I1;
-        constexpr auto N01 = I1;
-
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3>{}),
-                make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
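MakeDefaultBlock2CTileMap now simply returns a BlockToCTileMap_M00_N0_M01Adapt object instead of chaining tensor adaptors by hand. The sketch below shows only the basic arithmetic such a map must perform, a plain row-major split of the 1-D workgroup id into an (m0, n0) C-tile coordinate; the Adapt variant additionally regroups M-tiles in stripes for cache locality, which is not reproduced here (names and layout are illustrative, not the CK type).

struct RowMajorBlock2CTileMapSketch // illustration only
{
    int n_blocks; // N / NPerBlock

    // decompose a 1-D workgroup id into a C-tile coordinate
    constexpr void CalculateTileIndex(int block_id, int& m0, int& n0) const
    {
        m0 = block_id / n_blocks;
        n0 = block_id % n_blocks;
    }
};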
@@ -362,18 +319,17 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               FloatD* __restrict__ p_d0_grid,
-                               FloatD* __restrict__ p_d1_grid,
+                               DPtrsGlobal p_ds_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
-                               const D0ReduceOperation& d0_reduce_op,
-                               const D1ReduceOperation& d1_reduce_op,
+                               const DxsInElementwiseOperation& dxs_in_element_op,
+                               const DxsAccElementwiseOperation& dxs_out_element_op,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -387,15 +343,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-        auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
-        auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());

         // divide block work by [M, N]
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
@@ -414,7 +374,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               AElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
@@ -445,7 +405,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               BElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
@@ -484,8 +444,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr index_t KPack = math::max(
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

-        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
-            BlockSize,
+        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+            BlockSize,
             FloatAB,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
@@ -494,7 +454,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             NPerXdl,
             MXdlPerWave,
             NXdlPerWave,
-            KPack>{};
+            KPack,
+            LoopSched>();

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -514,28 +475,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
+            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                    a_block_desc_ak0_m_ak1,
                                    a_blockwise_copy,
                                    a_grid_buf,
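The pipeline object is now obtained from GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>() and invoked through Run<HasMainKBlockLoop>. The following is only a conceptual sketch of the single-stage prefetch pattern such a pipeline implements (the copy/gemm helper methods are hypothetical, not CK interfaces): the global load of K-tile i+1 is issued before the math on tile i, so memory latency overlaps compute.

template <typename Copy, typename Gemm>
void prefetch_pipeline_sketch(Copy& copy_a, Copy& copy_b, Gemm& gemm, int num_loop)
{
    copy_a.LoadFromGlobal(0); // prefetch the first K tile of A and B
    copy_b.LoadFromGlobal(0);

    for(int i = 0; i < num_loop; ++i)
    {
        copy_a.StoreToLds();
        copy_b.StoreToLds();

        if(i + 1 < num_loop)
        {
            copy_a.LoadFromGlobal(i + 1); // overlap the next load with this tile's math
            copy_b.LoadFromGlobal(i + 1);
        }

        gemm.RunBlockTile(); // barrier, then multiply-accumulate the tile now in LDS
    }
}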
@@ -551,7 +497,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                    c_thread_buf,
                                    num_k_block_main_loop);

-        // shuffle C and write out
+        // shuffle C + reduction + write out
         {
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -665,8 +611,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 ck::tensor_operation::element_wise::PassThrough{}};

             // shuffle: blockwise copy C from LDS to global
-            auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
-                BlockSize,                  // index_t BlockSize,
+            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
+                ThisThreadBlock,            // ThreadGroup
                 CElementwiseOperation,      // ElementwiseOperation,
                 CGlobalMemoryDataOperation, // DstInMemOp,
                 Sequence<1,
@@ -690,6 +636,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 c_element_op};

+            // space filling curve for threadwise C in VGPR
+            constexpr auto sfc_c_vgpr =
+                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                  Sequence<CShuffleMXdlPerWavePerShuffle,
+                                           CShuffleNXdlPerWavePerShuffle,
+                                           1, 1, M2, 1, M4, 1>>{};
+
+            // space filling curve for shuffled blockwise C in global mem
+            constexpr auto sfc_c_global =
+                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
+                                  Sequence<0, 2, 1, 3>,
+                                  Sequence<1,
+                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                                           1,
+                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+
+            // TODO: this should be implemented as a blockwise reduction
+
             // LDS c_reduce_block_desc_mperblock_nperblock
             constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
@@ -740,16 +709,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             constexpr auto d_reduce_thread_desc_mblock_mperblock =
                 make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

-            // TODO: this should be implemented as a blockwise reduction
-            auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
+            auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
                 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

-            auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
-                d_reduce_thread_desc_mperblock.GetElementSpaceSize());
-            auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
-                d_reduce_thread_desc_mperblock.GetElementSpaceSize());
-
             // reduce: threadwise copy from LDS to VGPR
             constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
                 CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
@@ -763,7 +725,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
                 FloatCShuffle,
-                FloatCShuffle,
+                FloatReduceAcc,
                 decltype(c_reduce_block_desc_mperblock_nperblock),
                 decltype(c_reduce_thread_desc_mperblock_nperblock),
                 decltype(c_reduce_thread_lengths_mperblock_nperblock),
@@ -773,47 +735,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 1,
                 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};

             // reduce: copy from VGPR to global
-            auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-                FloatCShuffle,
-                FloatD,
-                decltype(d_reduce_thread_desc_mblock_mperblock),
-                decltype(d_grid_desc_mblock_mperblock),
-                ck::tensor_operation::element_wise::PassThrough,
-                Sequence<1, mreduce_per_thread>,
-                Sequence<0, 1>,
-                1,
-                CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-                DGlobalMemoryDataOperation,
-                1,
-                false>{d_grid_desc_mblock_mperblock,
-                       make_multi_index(block_work_idx[I0],                  // mblock
-                                        c_reduce_thread_data_idx_begin[I0]), // mperblock
-                       ck::tensor_operation::element_wise::PassThrough{}};
-
-            auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global;
-
-            // space filling curve for threadwise C in VGPR
-            constexpr auto sfc_c_vgpr =
-                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
-                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                  Sequence<CShuffleMXdlPerWavePerShuffle,
-                                           CShuffleNXdlPerWavePerShuffle,
-                                           1, 1, M2, 1, M4, 1>>{};
-
-            // space filling curve for shuffled blockwise C in global mem
-            constexpr auto sfc_c_global =
-                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
-                                  Sequence<0, 2, 1, 3>,
-                                  Sequence<1,
-                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                                           1,
-                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+            auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
+                [&](auto I) {
+                    auto p_d_grid         = p_ds_grid[I];
+                    auto d_out_element_op = dxs_out_element_op[I];
+
+                    return ThreadwiseTensorSliceTransfer_v1r3<
+                        FloatReduceAcc,
+                        remove_pointer_t<decltype(p_d_grid)>,
+                        decltype(d_reduce_thread_desc_mblock_mperblock),
+                        decltype(d_grid_desc_mblock_mperblock),
+                        decltype(d_out_element_op),
+                        Sequence<1, mreduce_per_thread>,
+                        Sequence<0, 1>,
+                        1,
+                        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+                        DGlobalMemoryDataOperation::At(I),
+                        1,
+                        false>{d_grid_desc_mblock_mperblock,
+                               make_multi_index(block_work_idx[I0],                  // mblock
+                                                c_reduce_thread_data_idx_begin[I0]), // mperblock
+                               d_out_element_op};
+                },
+                Number<p_ds_grid.Size()>{});

             constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
@@ -840,48 +784,73 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                      c_grid_buf);

-                // reduce
+                // TODO - extract following into reduction_blockwise
                 {
                     // copy from LDS to VGPR
                     c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
                                                          c_shuffle_block_buf,
                                                          c_reduce_thread_desc_mperblock_nperblock,
                                                          make_tuple(I0, I0),
                                                          c_reduce_thread_buf);

-                    // reduce in VGPR
-                    static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-                        FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
-                        FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();
-
-                        static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
-                            constexpr auto offset =
-                                Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
-                                    make_tuple(im, in))>{};
-
-                            d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]);
-                            d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
-                        });
-
-                        constexpr index_t out_offset =
-                            d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));
-
-                        d0_thread_buf(Number<out_offset>{}) = d0_acc;
-                        d1_thread_buf(Number<out_offset>{}) = d1_acc;
-                    });
-
-                    // copy from VGPR to Global
-                    d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
-                                                             make_tuple(I0, I0),
-                                                             d0_thread_buf,
-                                                             d_grid_desc_mblock_mperblock,
-                                                             d0_grid_buf);
-
-                    d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
-                                                             make_tuple(I0, I0),
-                                                             d1_thread_buf,
-                                                             d_grid_desc_mblock_mperblock,
-                                                             d1_grid_buf);
+                    static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
+                        auto& p_d_grid  = p_ds_grid[In];
+                        auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                            p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
+
+                        auto d_thread_buf =
+                            make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
+                                d_reduce_thread_desc_mperblock.GetElementSpaceSize());
+
+                        auto& d_in_element_op = dxs_in_element_op[In];
+                        auto& d_reduce_thread_copy_vgpr_to_global =
+                            dxs_reduce_thread_copy_vgpr_to_global(In);
+
+                        using DReduceOperation =
+                            remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
+                        using ThreadwiseReduce =
+                            ThreadwiseReduction<FloatReduceAcc,
+                                                decltype(c_reduce_thread_desc_mperblock_nperblock),
+                                                decltype(d_reduce_thread_desc_mperblock),
+                                                DReduceOperation,
+                                                false>;
+
+                        // Global write Gemm shuffle + reduction
+                        const auto d_identityVal = DReduceOperation::GetIdentityValue();
+
+                        static_for<0, mreduce_per_thread, 1>{}(
+                            [&](auto I) { d_thread_buf(I) = d_identityVal; });
+
+                        // reduce in VGPR
+                        static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
+                            static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
+                                constexpr auto offset =
+                                    Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
+                                        make_tuple(im, in))>{};
+
+                                d_in_element_op(c_reduce_thread_buf(offset),
+                                                c_reduce_thread_buf(offset));
+                            });
+                        });
+
+                        ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
+
+                        // copy from VGPR to Global
+                        d_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
+                                                                make_tuple(I0, I0),
+                                                                d_thread_buf,
+                                                                d_grid_desc_mblock_mperblock,
+                                                                d_grid_buf);
+
+                        if constexpr(access_id < num_access - 1)
+                        {
+                            constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
+                            d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                                d_grid_desc_mblock_mperblock,
+                                make_tuple(c_global_step[I0], c_global_step[I1]));
+                        }
+                    });
                 }

                 if constexpr(access_id < num_access - 1)
@@ -891,18 +860,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     // move on C
                     c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
-
-                    // move on D0
-                    d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                        d_grid_desc_mblock_mperblock,
-                        make_tuple(c_global_step[I0], c_global_step[I1]));
-
-                    // move on D1
-                    d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                        d_grid_desc_mblock_mperblock,
-                        make_tuple(c_global_step[I0], c_global_step[I1]));
                 }
             });
-            // Reduction
         }
     }
 };
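The rewritten epilogue drives one generic loop over the tuple of D outputs: each accumulator row is seeded with the reduce operation's identity, the per-element input op is applied in place (for example squaring, when D is a sum of squares), and a threadwise reduction then collapses the N direction. A scalar sketch of that pattern (types and bounds are assumed; this is an illustration, not the CK implementation):

template <int MPerThread, int NPerThread, typename InOp, typename ReduceOp>
void threadwise_reduce_sketch(float (&c)[MPerThread][NPerThread],
                              float (&d)[MPerThread],
                              InOp in_op, ReduceOp reduce_op, float identity)
{
    for(int im = 0; im < MPerThread; ++im)
    {
        d[im] = identity;                        // seed with the reduce identity
        for(int in = 0; in < NPerThread; ++in)
        {
            in_op(c[im][in], c[im][in]);         // elementwise pre-op, applied in place
            d[im] = reduce_op(d[im], c[im][in]); // fold into the row accumulator
        }
    }
}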
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
@@ -3,9 +3,10 @@
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
@@ -21,7 +22,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -41,7 +42,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
@@ -107,7 +108,8 @@ template <typename FloatAB,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
+          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          LoopScheduler LoopSched>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
@@ -125,6 +127,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
+
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -185,15 +191,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                  const CGridDesc_M_N& c_grid_desc_m_n)
+                  const CGridDesc_M_N& c_grid_desc_m_n,
+                  const Block2CTileMap& block_2_ctile_map)
     {
         // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
         //     is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
         //     "wrong! K1 need to be known at compile-time");

         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");
@@ -208,21 +212,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+        {
+            return false;
+        }
+
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
         {
             return false;
         }
@@ -231,23 +229,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return true;
     }

-    __host__ __device__ static constexpr index_t
-    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-
-        return grid_size;
-    }
-
-    // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop =
-            ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
+        const index_t num_loop = K / KPerBlock;

-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

     __host__ __device__ static constexpr auto
@@ -273,40 +259,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        // FIXME: remove
-        constexpr auto M01 = I1;
-        constexpr auto N01 = I1;
-
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3>{}),
-                make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
@@ -315,7 +269,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
@@ -340,6 +294,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
@@ -358,7 +320,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               AElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
@@ -389,7 +351,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               BElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
@@ -428,8 +390,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr index_t KPack = math::max(
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

-        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
-            BlockSize,
+        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+            BlockSize,
             FloatAB,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
@@ -438,7 +400,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             NPerXdl,
             MXdlPerWave,
             NXdlPerWave,
-            KPack>{};
+            KPack,
+            LoopSched>();

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -458,28 +421,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
+            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                    a_block_desc_ak0_m_ak1,
                                    a_blockwise_copy,
                                    a_grid_buf,
@@ -609,8 +557,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 ck::tensor_operation::element_wise::PassThrough{}};

             // shuffle: blockwise copy C from LDS to global
-            auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
-                BlockSize,                  // index_t BlockSize,
+            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
+                ThisThreadBlock,            // ThreadGroup
                 CElementwiseOperation,      // ElementwiseOperation,
                 CGlobalMemoryDataOperation, // DstInMemOp,
                 Sequence<1,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp (new file)
#pragma once

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v3.hpp"

namespace ck {
// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
template <typename LowLengths>
struct Merge_v4_no_carry
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v4_no_carry() = default;

    __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        // division and mod
        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_low(i) = tmp / this->low_lengths_scan_[i];
            tmp %= this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_up_diff,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0   = Number<0>{};
        constexpr auto INm1 = Number<NDimLow - 1>{};

        index_t tmp = idx_up_new[I0];

        idx_low(INm1)      = tmp;
        idx_diff_low(INm1) = idx_up_diff[I0];
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScan>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v3_direct_division_mod_wrw, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan_ ");
        print_multi_index(low_lengths_scan_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};

template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths)
{
    return Merge_v4_no_carry<LowLengths>{low_lengths};
}
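// Worked example (illustration only, not part of the original header): merging low
// lengths {4, 8} into one upper length 32 with Merge_v4_no_carry. low_lengths_scan_
// is the reverse exclusive product scan, here {8, 1}, so CalculateLowerIndex maps
// upper index 27 to (27 / 8, 27 % 8) = (3, 3), and indeed 3 * 8 + 3 == 27.
// UpdateLowerIndex only rewrites the last lower dimension, which is why this
// transform is meant for compile-time, power-of-two lengths where no carry
// between lower dimensions is ever needed.
static_assert(27 / 8 == 3 && 27 % 8 == 3, "upper index 27 over lows {4, 8} -> lower (3, 3)");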
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename CBlockClusterAdaptor,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_bwd_weight(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
            const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared_block,
                                                  a_b_k0_m_k1_grid_desc,
                                                  b_b_k0_n_k1_grid_desc,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op,
                                                  c_block_cluster_adaptor);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_b_k0_m_k1_grid_desc;
    ignore = b_b_k0_n_k1_grid_desc;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
    ignore = c_block_cluster_adaptor;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CMNGridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t K1Value,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          index_t ABlockLdsM1PerBlock,
          index_t ABlockLdsM0PerBlock,
          index_t ABlockLdsM1Padding,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          index_t BBlockLdsN1PerBlock,
          index_t BBlockLdsN0PerBlock,
          index_t BBlockLdsN1Padding,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          bool ABlockLdsExtraM1Wrw      = false,
          bool BBlockLdsExtraN1Wrw      = false,
          index_t NumGemmKPrefetchStage = 1>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    // using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
    using GridwiseGemmPipe = GridwiseGemmPipeline_v3;

    // M0/M1/M1Padding
    static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
    static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};
    static constexpr auto M1Padding  = Number<ABlockLdsM1Padding>{};

    // N0/N1/N1Padding
    static constexpr auto N1PerBlock = Number<BBlockLdsN1PerBlock>{};
    static constexpr auto N0PerBlock = Number<BBlockLdsN0PerBlock>{};
    static constexpr auto N1Padding  = Number<BBlockLdsN1Padding>{};
    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                if constexpr(ABlockLdsExtraM1Wrw)
                {
                    constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
                        make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
                                   Number<M1PerBlock>{} * K1 + M1Padding,
                                   K1,
                                   I1));

                    constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
                        a_block_desc_k0_m0_m1_k1,
                        make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v3_division_mod(
                                       make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

                    return a_block_desc_k0_m_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_b_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                if constexpr(ABlockLdsExtraM1Wrw)
                {
                    constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
                        make_tuple(Number<1>{},
                                   Number<K0PerBlock>{},
                                   Number<M0PerBlock>{},
                                   Number<M1PerBlock>{},
                                   K1),
                        make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} *
                                       (Number<M1PerBlock>{} * K1 + M1Padding),
                                   Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
                                   Number<M1PerBlock>{} * K1 + M1Padding,
                                   K1,
                                   I1));

                    constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor(
                        a_block_desc_b_k0_m0_m1_k1,
                        make_tuple(make_pass_through_transform(Number<1>{}),
                                   make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v4_no_carry(
                                       make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

                    return a_block_desc_b_k0_m_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                                   Number<MPerBlock + 1>{} * K1,
                                   K1,
                                   I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    max_lds_align);
            }
        }();

        return a_block_desc_b_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                if constexpr(BBlockLdsExtraN1Wrw)
                {
                    constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
                        make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
                                   Number<N1PerBlock>{} * K1 + N1Padding,
                                   K1,
                                   I1));

                    constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
                        b_block_desc_k0_n0_n1_k1,
                        make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v3_division_mod(
                                       make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

                    return b_block_desc_k0_n_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0_n_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_b_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                if constexpr(BBlockLdsExtraN1Wrw)
                {
                    constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
                        make_tuple(Number<1>{},
                                   Number<K0PerBlock>{},
                                   Number<N0PerBlock>{},
                                   Number<N1PerBlock>{},
                                   K1),
                        make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} *
                                       (Number<N1PerBlock>{} * K1 + N1Padding),
                                   Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
                                   Number<N1PerBlock>{} * K1 + N1Padding,
                                   K1,
                                   I1));

                    constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
                        b_block_desc_b_k0_n0_n1_k1,
                        make_tuple(make_pass_through_transform(Number<1>{}),
                                   make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v4_no_carry(
                                       make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

                    return b_block_desc_b_k0_n_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                                   Number<NPerBlock + 1>{} * K1,
                                   K1,
                                   I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    max_lds_align);
            }
        }();

        return b_block_desc_b_k0_n_k1;
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size = math::integer_least_multiple(
            b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_block_size =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();

        return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
                         c_block_size * sizeof(FloatC));
    }
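    // Worked example (assumed tuning values, illustration only): with K0PerBlock = 4,
    // K1 = 8, MPerBlock = NPerBlock = 128, no extra LDS padding and FloatAB = half,
    // the A and B block tiles each hold 4 * 128 * 8 = 4096 elements, so the GEMM stage
    // needs (4096 + 4096) * 2 bytes = 16 KiB of LDS; the C-shuffle stage reuses the same
    // allocation, and the function above returns the larger of the two requirements.
    static_assert((4 * 128 * 8 + 4 * 128 * 8) * 2 == 16384, "fp16 A+B tiles -> 16 KiB of LDS");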
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                  const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                  const CMNGridDesc& c_m_n_grid_desc,
                  const Block2CTileMap& block_2_ctile_map)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
                      "Invalid tuning param!");

        const auto M      = a_b_k0_m_k1_grid_desc.GetLength(I2);
        const auto N      = b_b_k0_n_k1_grid_desc.GetLength(I2);
        const auto K0     = a_b_k0_m_k1_grid_desc.GetLength(I1);
        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

        // check gridwise gemm pipeline
        const auto num_k_loop = K0 / K0PerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
             K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
             K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
             K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
             KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        // const bool has_main_k0_block_loop = K0 > K0PerBlock;
        const index_t num_loop = K0 / K0PerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
        // return has_main_k0_block_loop;
    }
    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        return transform_tensor_descriptor(
            c_m_n_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
    {
        return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
            c_m_n_grid_desc, M01, N01, KBatch);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
                       I1,
                       Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
    }

    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
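    // Illustration only: the unmerge above splits a row index m into
    // (mblock, mperblock) = (m / MPerBlock, m % MPerBlock); e.g. with MPerBlock = 128,
    // row 300 falls in block 2 at in-block offset 44.
    static_assert(300 / 128 == 2 && 300 % 128 == 44, "m = 300 -> (MBlock 2, offset 44)");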
    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                               const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        if(!c_block_cluster_adaptor.ValidCTileIndex(
               make_tuple(block_work_idx[I1], block_work_idx[I2]),
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc   = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc   = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<1, K0PerBlock, MPerBlock, K1>,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                FloatAB,
                                                FloatAB,
                                                decltype(a_b_k0_m_k1_grid_desc),
                                                decltype(a_b_k0_m_k1_block_desc),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<0, 2, 1, 3>,
                                                ABlockTransferSrcVectorDim,
                                                3,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                a_b_k0_m_k1_grid_desc,
                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_b_k0_m_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<1, K0PerBlock, NPerBlock, K1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                FloatAB,
                                                FloatAB,
                                                decltype(b_b_k0_n_k1_grid_desc),
                                                decltype(b_b_k0_n_k1_block_desc),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<0, 2, 1, 3>,
                                                BBlockTransferSrcVectorDim,
                                                3,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                b_b_k0_n_k1_grid_desc,
                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_b_k0_n_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check
        constexpr index_t KPack =
            math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
                                                                KPack,
                                                                KPack / N1Padding>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = p_shared_block;
        FloatAB* p_b_block = p_shared_block + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

        // gridwise GEMM pipeline
        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
                                                          a_b_k0_m_k1_block_desc,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_b_k0_n_k1_grid_desc,
                                                          b_b_k0_n_k1_block_desc,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          K0BlockMainLoop);
// output: register to global memory
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I5
);
constexpr
auto
M4
=
            c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
        constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

        constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatC*>(p_shared_block),
            c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        static_assert(M1 == MWave, "");
        static_assert(N1 == NWave, "");
        static_assert(M2 * M3 * M4 == MPerXDL, "");
        static_assert(N2 == NPerXDL, "");

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_block_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(
                make_freeze_transform(I0), // freeze mblock
                make_unmerge_transform(
                    make_tuple(CShuffleMRepeatPerShuffle, M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
                make_freeze_transform(I0), // freeze nblock
                make_unmerge_transform(
                    make_tuple(CShuffleNRepeatPerShuffle, N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

        // calculate origin of thread output tensor on global memory
        // blockwise GEMM c matrix starting index
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        // VGPR to LDS
        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            FloatAcc,
            FloatC,
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            ck::tensor_operation::element_wise::PassThrough,
            Sequence<CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, I1, I1, M2, I1, M4, I1>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
            7,
            1,
            InMemoryDataOperationEnum::Set,
            1,
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  make_multi_index(0,
                                   0,
                                   m_thread_data_on_block_idx[I1],
                                   n_thread_data_on_block_idx[I1],
                                   m_thread_data_on_block_idx[I2],
                                   m_thread_data_on_block_idx[I3],
                                   m_thread_data_on_block_idx[I4],
                                   n_thread_data_on_block_idx[I2]),
                  ck::tensor_operation::element_wise::PassThrough{}};

        // LDS to global
        auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
            ThisThreadBlock,            // index_t BlockSize,
            CElementwiseOperation,      // ElementwiseOperation,
            CGlobalMemoryDataOperation, // DstInMemOp,
            Sequence<1,
                     CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                     1,
                     CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            FloatC,               // typename SrcData,
            FloatC,               // typename DstData,
            decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
            3,                    // index_t VectorDim,
            CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
            true,   // bool ThreadTransferSrcResetCoordinateAfterRun,
            false>  // bool ThreadTransferDstResetCoordinateAfterRun
            {c_block_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(0, 0, 0, 0),
             c_grid_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
             c_element_op};

        constexpr auto mxdlperwave_forward_step =
            make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
        constexpr auto nxdlperwave_forward_step =
            make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
        constexpr auto nxdlperwave_backward_step =
            make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);

        static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
            constexpr auto mxdlperwave = mxdlperwave_iter;

            static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                constexpr bool nxdlperwave_forward_sweep =
                    (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                constexpr index_t nxdlperwave_value =
                    nxdlperwave_forward_sweep
                        ? nxdlperwave_iter
                        : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                // make sure it's safe to do ds_write
                block_sync_lds();

                // VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(
                    c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                    make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                    c_thread_buf,
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    c_block_buf);

                // make sure it's safe to do ds_read
                block_sync_lds();

                // LDS to global
                c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
                                               c_block_buf,
                                               c_grid_desc_mblock_mperblock_nblock_nperblock,
                                               c_grid_buf);

                // move on nxdlperwave dimension
                if constexpr(nxdlperwave_forward_sweep &&
                             (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, nxdlperwave_forward_step);
                }
                else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, nxdlperwave_backward_step);
                }
            });

            // move on mxdlperwave dimension
            if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
            {
                c_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
            }
        });
    }
}
};

} // namespace ck
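Note: the epilogue above writes C tiles in a serpentine order so that the destination slice window only ever moves by one shuffle step between tiles. The following is an illustrative, host-side sketch (not library code) of that visiting order; the tile counts are made-up example values standing in for MRepeat/NRepeat and CShuffleM/NRepeatPerShuffle.

// Illustrative only: prints the forward/backward (serpentine) order in which the
// C-shuffle epilogue visits (m, n) shuffle tiles.
#include <cstdio>

int main()
{
    constexpr int MRepeat = 4, NRepeat = 4;
    constexpr int MShuffle = 2, NShuffle = 2; // stand-ins for CShuffleM/NRepeatPerShuffle

    for(int m = 0; m < MRepeat; m += MShuffle)
    {
        // even M sweeps run N forward, odd M sweeps run N backward
        const bool forward = (m % (2 * MShuffle) == 0);

        for(int n_iter = 0; n_iter < NRepeat; n_iter += NShuffle)
        {
            const int n = forward ? n_iter : (NRepeat - n_iter - NShuffle);
            std::printf("visit tile (m=%d, n=%d)\n", m, n);
        }
    }
    return 0;
}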
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
+#pragma once
 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"

@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -42,7 +42,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                    p_b_grid,
                                                    p_c_grid,
                                                    p_shared,

@@ -67,88 +67,6 @@ __global__ void
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename GemmDesc,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          bool HasMainK0BlockLoop,
-          index_t MaxGroupCount>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-        kernel_grouped_gemm_xdlops_v2r3(const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
-                                        const index_t group_count,
-                                        const AElementwiseOperation a_element_op,
-                                        const BElementwiseOperation b_element_op,
-                                        const CElementwiseOperation c_element_op)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    const index_t block_id = get_block_1d_id();
-#if 1
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
-           i < group_count)
-        {
-            auto group_id = i;
-
-            GridwiseGemm::template Run<HasMainK0BlockLoop>(
-                gemm_desc_[group_id].a_ptr,
-                gemm_desc_[group_id].b_ptr,
-                gemm_desc_[group_id].c_ptr,
-                p_shared,
-                gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
-                gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
-                gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                a_element_op,
-                b_element_op,
-                c_element_op,
-                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
-        }
-    });
-#else
-    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
-
-    index_t group_id = 0;
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
-                    i < group_count)
-                       ? i
-                       : group_id;
-    });
-
-    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
-
-    GridwiseGemm::template Run<HasMainK0BlockLoop>(gemm_desc_ptr[group_id].a_ptr,
-                                                   gemm_desc_ptr[group_id].b_ptr,
-                                                   gemm_desc_ptr[group_id].c_ptr,
-                                                   p_shared,
-                                                   gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
-                                                   gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
-                                                   gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                   a_element_op,
-                                                   b_element_op,
-                                                   c_element_op,
-                                                   gemm_desc_ptr[group_id].block_2_ctile_map_,
-                                                   block_id_grp);
-#endif
-#else
-    ignore = gemm_desc_;
-    ignore = group_count;
-    ignore = a_element_op;
-    ignore = b_element_op;
-    ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
-}

 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,

@@ -187,7 +105,7 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 {
     static constexpr auto I0 = Number<0>{};

@@ -202,6 +120,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;

@@ -265,12 +187,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
-                  index_t M01,
-                  index_t N01)
+                  const Block2CTileMap& block_2_ctile_map)
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");

@@ -291,56 +213,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
-            return false;
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
+            return false;
+        }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);

         const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

         return grid_size;
     }

     // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
-
-        return has_main_k0_block_loop;
+        const index_t num_loop = K / (K0PerBlock * K1);
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

@@ -394,46 +288,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(c_grid_desc_m_n);
     }

     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,

@@ -460,6 +326,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
+                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

@@ -478,7 +352,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -499,7 +373,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,

@@ -509,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -530,7 +404,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,

@@ -575,27 +449,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-        // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
-        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-
-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                    a_block_desc_k0_m_k1,
                                    a_blockwise_copy,
                                    a_grid_buf,

@@ -609,7 +465,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                    b_block_slice_copy_step,
                                    blockwise_gemm,
                                    c_thread_buf,
-                                   K0BlockMainLoop);
+                                   num_k_block_main_loop);

         // output: register to global memory
         {

@@ -692,4 +548,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 };

 } // namespace ck
-#endif
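Note: the grouped-GEMM entry point removed above dispatches each workgroup to the GEMM whose block range contains its 1-D block id. The following is an illustrative, host-side sketch (not library code) of that lookup; GroupDesc and the linear search are simplified stand-ins for the GemmDesc argument and the static_for over MaxGroupCount.

// Illustrative only: block-id -> group-id lookup over half-open block ranges.
#include <cstdio>
#include <vector>

struct GroupDesc
{
    int BlockStart_;
    int BlockEnd_; // exclusive
};

int FindGroupId(const std::vector<GroupDesc>& groups, int block_id)
{
    for(int i = 0; i < static_cast<int>(groups.size()); ++i)
    {
        if(block_id >= groups[i].BlockStart_ && block_id < groups[i].BlockEnd_)
            return i;
    }
    return -1; // block id outside every group's range
}

int main()
{
    // three GEMMs occupying 8, 4 and 16 workgroups respectively
    const std::vector<GroupDesc> groups = {{0, 8}, {8, 12}, {12, 28}};

    for(int block_id : {0, 9, 27})
        std::printf("block %d -> group %d\n", block_id, FindGroupId(groups, block_id));
    return 0;
}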
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp

@@ -5,8 +5,9 @@
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"

 namespace ck {

@@ -120,6 +121,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr auto max_lds_align = K1;

@@ -165,12 +168,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                                                             const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                                             const CMNGridDesc& c_m_n_grid_desc,
-                                                            index_t M01,
-                                                            index_t N01)
+                                                            const Block2CTileMap& block_2_ctile_map)
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");

@@ -194,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
+        if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
         {
             return false;
         }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
     {
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);

         const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;

         return grid_size;
     }

     __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
     {
         const bool has_main_k0_block_loop = K0 > K0PerBlock;

@@ -278,39 +265,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
-        const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
+        const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
     {
-        const auto M = c_m_n_grid_desc.GetLength(I0);
-        const auto N = c_m_n_grid_desc.GetLength(I1);
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_pass_through_transform(KBatch),
-                       make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
-
-        const auto cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_kbatch_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_kbatch_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
+            c_m_n_grid_desc, 8, KBatch);
     }

     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));

@@ -342,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         const auto block_work_idx =
             c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!c_block_cluster_adaptor.ValidCTileIndex(
+               make_tuple(block_work_idx[I1], block_work_idx[I2]),
+               make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0),
+                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1))))
+        {
+            return;
+        }
+
         const index_t k_batch_id = block_work_idx[I0];

         // HACK: this force m/n_block_data_idx_on_grid into SGPR

@@ -420,7 +386,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         }();

         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -450,7 +416,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
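Note: the split-K block-to-C-tile map introduced above has to turn a 1-D block id into a (k-batch, m-tile, n-tile) triple. The following is an illustrative, host-side sketch (not library code) of that decomposition; the real BlockToCTileMap_KSplit_M00_N0_M01Adapt additionally groups m-tiles (the value hard-coded to 8 in the diff) for L2 locality, which is omitted here.

// Illustrative only: simplified 1-D block id -> (k_batch, m0, n0) decomposition.
#include <cstdio>

struct Tile { int kb, m0, n0; };

Tile DecomposeBlockId(int block_id, int m_tiles, int n_tiles)
{
    const int tiles_per_batch = m_tiles * n_tiles;
    Tile t;
    t.kb = block_id / tiles_per_batch;
    t.m0 = (block_id % tiles_per_batch) / n_tiles;
    t.n0 = block_id % n_tiles;
    return t;
}

int main()
{
    // made-up problem/tile sizes
    const int M = 512, N = 768, MPerBlock = 128, NPerBlock = 128, KBatch = 2;
    const int m_tiles = M / MPerBlock, n_tiles = N / NPerBlock;

    for(int id = 0; id < m_tiles * n_tiles * KBatch; id += 7)
    {
        const Tile t = DecomposeBlockId(id, m_tiles, n_tiles);
        std::printf("block %2d -> (kb=%d, m0=%d, n0=%d)\n", id, t.kb, t.m0, t.n0);
    }
    return 0;
}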
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

@@ -6,9 +6,10 @@
 #include "merge_transform_for_wrw.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"

 namespace ck {

@@ -130,17 +131,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

-    // M1 & M0
-    static constexpr auto M1PerBlock = Number<ElePerBank / K1Value>{};
-    static constexpr auto M0PerBlock = Number<MPerBlock / M1PerBlock>{};
-    static constexpr auto M1Padding = I4;
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    // N1 & N0
-    static constexpr auto N1PerBlock = Number<ElePerBank / K1Value>{};
-    static constexpr auto N0PerBlock = Number<NPerBlock / M1PerBlock>{};
-    static constexpr auto N1Padding = I4;
-
-    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr auto max_lds_align = K1;

@@ -369,12 +362,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                                                             const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                                                             const CMNGridDesc& c_m_n_grid_desc,
-                                                            index_t M01,
-                                                            index_t N01)
+                                                            const Block2CTileMap& block_2_ctile_map)
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");

@@ -398,31 +391,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
+        if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
         {
             return false;
         }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
     {
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);

         const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;

         return grid_size;
     }

     __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
     {
         const bool has_main_k0_block_loop = K0 > K0PerBlock;

@@ -449,39 +426,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
-        const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
+        const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
     {
-        const auto M = c_m_n_grid_desc.GetLength(I0);
-        const auto N = c_m_n_grid_desc.GetLength(I1);
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_pass_through_transform(KBatch),
-                       make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
-
-        const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
+            c_m_n_grid_desc, 8, KBatch);
     }

@@ -528,6 +476,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         const auto block_work_idx =
             c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!c_block_cluster_adaptor.ValidCTileIndex(
+               make_tuple(block_work_idx[I1], block_work_idx[I2]),
+               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        {
+            return;
+        }
+
         const index_t k_batch_id = block_work_idx[I0];

         // HACK: this force m/n_block_data_idx_on_grid into SGPR

@@ -550,7 +506,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();

         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -580,7 +536,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -829,8 +785,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             ck::tensor_operation::element_wise::PassThrough{}};

         // LDS to global
-        auto c_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r1<BlockSize,                  // index_t BlockSize,
+        auto c_block_copy_lds_to_global =
+            ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock,          // index_t BlockSize,
                                               CElementwiseOperation,      // ElementwiseOperation,
                                               CGlobalMemoryDataOperation, // DstInMemOp,
                                               Sequence<1,
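Note: several kernels in this commit delegate their loop-count checks to the new pipeline class. The sketch below is an assumption about what those predicates amount to, inferred from the checks that were removed (1-stage prefetch always supported, 2-stage prefetch requiring an even number of K loops); the exact bodies of GridwiseGemmPipeline_v1::IsSupported and CalculateHasMainLoop live in gridwise_gemm_pipeline_v1.hpp and may differ.

// Illustrative only, host-side: approximate pipeline support / main-loop predicates.
#include <cstdio>

bool IsSupported(int num_loop, int num_prefetch_stage)
{
    if(num_prefetch_stage == 1)
        return true;              // 1-stage prefetch: any K-loop count
    if(num_prefetch_stage == 2)
        return num_loop % 2 == 0; // 2-stage prefetch: even K-loop count only
    return false;
}

bool CalculateHasMainLoop(int num_loop, int num_prefetch_stage)
{
    // a "main loop" exists when more than one prefetch group of K iterations remains
    return (num_loop / num_prefetch_stage) > 1;
}

int main()
{
    const int K = 1024, K0PerBlock = 4, K1 = 8; // made-up tile sizes
    const int num_loop = K / (K0PerBlock * K1);

    std::printf("num_loop=%d supported(2-stage)=%d has_main_loop=%d\n",
                num_loop, IsSupported(num_loop, 2), CalculateHasMainLoop(num_loop, 2));
    return 0;
}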
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp

-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP
+#pragma once
 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
 #include "tensor_space_filling_curve.hpp"

@@ -113,7 +112,7 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 {
     static constexpr auto I0 = Number<0>{};

@@ -131,6 +130,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         constexpr auto max_lds_align = AK1;

@@ -221,12 +224,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
-                  index_t M01,
-                  index_t N01)
+                  const Block2CTileMap& block_2_ctile_map)
     {
         // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
         //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,

@@ -246,56 +249,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
-            return false;
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
+            return false;
+        }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);

         const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

         return grid_size;
     }

     // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1;
-
-        return has_main_k0_block_loop;
+        const index_t num_loop = K / KPerBlock;
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

@@ -325,39 +300,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
         remove_cvref_t<decltype(

@@ -367,7 +314,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,

@@ -395,6 +342,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl.GetLength(I0),
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl.GetLength(I3))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

@@ -413,7 +371,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -434,7 +392,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                                                 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 a_grid_desc_ak0_m_ak1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,

@@ -444,7 +402,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -465,7 +423,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                                                 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 b_grid_desc_bk0_n_bk1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,

@@ -512,29 +470,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

-        // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
+        GridwiseGemmPipe::template Run<HasMainK0BlockLoop>(a_grid_desc_ak0_m_ak1,
                                    a_block_desc_ak0_m_ak1,
                                    a_blockwise_copy,
                                    a_grid_buf,

@@ -672,8 +612,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             ck::tensor_operation::element_wise::PassThrough{}};

         // LDS to global
-        auto c_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r1<BlockSize,                  // index_t BlockSize,
+        auto c_block_copy_lds_to_global =
+            ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock,          // ThreadGroup
                                               CElementwiseOperation,      // ElementwiseOperation,
                                               CGlobalMemoryDataOperation, // DstInMemOp,
                                               Sequence<1,

@@ -774,4 +714,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 };

 } // namespace ck
-#endif
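Note: the ValidCTileIndex early-return added throughout these kernels exists because, with the adaptive tile maps, the launched grid may contain more workgroups than there are C tiles; a workgroup whose decoded (m0, n0) falls outside the tile grid simply returns. The following is an illustrative, host-side sketch (not library code) of that guard, with made-up tile counts.

// Illustrative only: out-of-range tiles get an early return instead of doing work.
#include <cstdio>

bool ValidCTileIndex(int m0, int n0, int m_tiles, int n_tiles)
{
    return m0 >= 0 && m0 < m_tiles && n0 >= 0 && n0 < n_tiles;
}

int main()
{
    const int m_tiles = 3, n_tiles = 5;

    for(int block_id = 0; block_id < 16; ++block_id)
    {
        const int m0 = block_id / n_tiles;
        const int n0 = block_id % n_tiles;

        if(!ValidCTileIndex(m0, n0, m_tiles, n_tiles))
        {
            std::printf("block %2d -> no tile, early return\n", block_id);
            continue; // the kernel just returns here
        }
        std::printf("block %2d -> tile (%d, %d)\n", block_id, m0, n0);
    }
    return 0;
}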
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp

@@ -5,9 +5,10 @@
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r2.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r2.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"

@@ -24,7 +25,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -48,7 +49,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(
+    GridwiseGemm::template Run<HasMainKBlockLoop>(
         p_a_grid,
         p_b_grid,
         p_c_grid,

@@ -119,7 +120,7 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 {
     static constexpr auto I0 = Number<0>{};

@@ -134,6 +135,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;

@@ -226,12 +231,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
-                  index_t M01,
-                  index_t N01)
+                  const Block2CTileMap& block_2_ctile_map)
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");

@@ -252,56 +257,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
-            return false;
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
+            return false;
+        }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

-    __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-
-        return grid_size;
-    }
-
-    // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
-    {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
-
-        return has_main_k0_block_loop;
+        const index_t num_loop = K / (K0PerBlock * K1);
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

     template <typename CGridDesc_M_N_>

@@ -332,40 +309,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
         remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(

@@ -379,7 +329,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,

@@ -416,6 +366,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl.GetLength(I0),
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl.GetLength(I3))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

@@ -434,7 +395,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -455,7 +416,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
                                                 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,

@@ -465,7 +426,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,

@@ -486,7 +447,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
                                                 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,

@@ -531,27 +492,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-        // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                    a_block_desc_k0_m_k1,
                                    a_blockwise_copy,
                                    a_grid_buf,

@@ -690,8 +633,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
                                            n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};

-        auto c_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r2<BlockSize,                  // index_t BlockSize,
+        auto c_block_copy_lds_to_global =
+            ThreadGroupTensorSliceTransfer_v6r2<ThisThreadBlock,          // index_t BlockSize,
                                               CElementwiseOperation,      // ElementwiseOperation,
                                               CGlobalMemoryDataOperation, // DstInMemOp,
                                               Sequence<1,
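Note: the per-kernel pipeline objects removed throughout this commit are replaced by a shared GridwiseGemmPipeline_v1 parameterised only by the prefetch-stage count. The following is an illustrative, host-side sketch (not library code) of the general ping-pong shape such a 2-stage prefetch main loop takes; buffer handling and the load/compute lambdas are stand-ins, not the library's implementation.

// Illustrative only: prologue / main loop / epilogue of a double-buffered K loop.
#include <array>
#include <cstdio>

template <typename LoadTile, typename ComputeTile>
void RunPipelined(int num_k_loop, LoadTile load, ComputeTile compute)
{
    std::array<int, 2> lds_buf{}; // two "LDS buffers", each holding one K tile

    load(0, lds_buf[0]); // prologue: prefetch the first tile

    for(int k = 0; k + 1 < num_k_loop; ++k)
    {
        load(k + 1, lds_buf[(k + 1) % 2]); // prefetch the next tile into the other buffer
        compute(lds_buf[k % 2]);           // compute on the tile loaded previously
    }

    compute(lds_buf[(num_k_loop - 1) % 2]); // epilogue: compute the final tile
}

int main()
{
    RunPipelined(
        4,
        [](int k, int& buf) { buf = k; std::printf("load  tile %d\n", k); },
        [](int buf) { std::printf("gemm on tile %d\n", buf); });
    return 0;
}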
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
a3b4c5cb
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "
blockwise
_tensor_slice_transfer_v4r1.hpp"
#include "
blockwise
_tensor_slice_transfer_v6r3.hpp"
#include "
thread_group
_tensor_slice_transfer_v4r1.hpp"
#include "
thread_group
_tensor_slice_transfer_v6r3.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
...
...
@@ -25,7 +24,7 @@ template <typename GridwiseGemm,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Block2CTileMap
,
bool
HasMainK
0
BlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -52,7 +51,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainK
0
BlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -128,7 +127,7 @@ template <
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
,
index_t
NumPrefetch
=
1
>
index_t
Num
GemmK
Prefetch
Stage
=
1
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -143,6 +142,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
...
...
@@ -235,12 +238,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -261,56 +264,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
// check NumPrefetch
if
constexpr
(
NumPrefetch
==
1
)
{
// 1-stage prefetch always supported
}
else
if
constexpr
(
NumPrefetch
==
2
)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if
(
!
((
K0
/
K0PerBlock
)
%
2
==
0
))
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K0
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
return
false
;
}
}
else
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
return
grid_size
;
}
// TODO move this function into GEMM-pipeline class
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
(
K0
/
(
NumPrefetch
*
K0PerBlock
))
>
1
;
return
has_main_k0_block_loop
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
template
<
typename
CGridDesc_M_N_
>
...
...
@@ -341,39 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
/* M01 */
,
index_t
/* N01 */
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
...
...
@@ -393,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
template
<
bool
HasMainK
0
BlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
@@ -437,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I0
),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I3
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
@@ -455,7 +413,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// A matrix blockwise copy
auto
a_blockwise_copy
=
Blockwise
TensorSliceTransfer_v4r1
<
Block
Size
,
ThreadGroup
TensorSliceTransfer_v4r1
<
ThisThread
Block
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
...
...
@@ -485,7 +443,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// B matrix blockwise copy
auto
b_blockwise_copy
=
Blockwise
TensorSliceTransfer_v4r1
<
Block
Size
,
ThreadGroup
TensorSliceTransfer_v4r1
<
ThisThread
Block
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
...
...
@@ -550,27 +508,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                           a_block_desc_k0_m_k1,
                                                           a_blockwise_copy,
                                                           a_grid_buf,
...
...
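The structural change here — dropping the GridwiseGemmPipeline_v1 object and calling the static GridwiseGemmPipe::Run instead — does not change what the main loop does: it software-pipelines the K0 dimension, overlapping the global-memory read of the next A/B slice with the math on the slice already staged in LDS. A heavily simplified sketch of that loop shape (illustrative placeholders only, not the CK pipeline implementation):

#include <hip/hip_runtime.h>

// Illustrative pipeline skeleton: CopyA/CopyB stand in for the blockwise copies
// above, Gemm for blockwise_gemm. None of these are CK types.
template <bool HasMainLoop, typename CopyA, typename CopyB, typename Gemm>
__device__ void run_gemm_pipeline(CopyA& a_copy, CopyB& b_copy, Gemm& gemm, int num_loop)
{
    a_copy.RunRead(); // prefetch first A slice: global memory -> registers
    b_copy.RunRead(); // prefetch first B slice

    if constexpr(HasMainLoop)
    {
        int i = 0;
        do
        {
            a_copy.RunWrite();      // stage current slice: registers -> LDS
            b_copy.RunWrite();
            a_copy.MoveSrcWindow(); // advance the global window along K0
            b_copy.MoveSrcWindow();
            a_copy.RunRead();       // prefetch the next slice while math runs
            b_copy.RunRead();

            __syncthreads();
            gemm.Run();             // consume the slice staged in LDS
            __syncthreads();
        } while(++i < num_loop - 1);
    }

    // tail: stage and consume the last prefetched slice
    a_copy.RunWrite();
    b_copy.RunWrite();
    __syncthreads();
    gemm.Run();
}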
@@ -623,15 +563,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
             c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-            make_tuple(make_freeze_transform(I0), // freeze mblock
-                       make_pass_through_transform(Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
+            make_tuple(make_freeze_transform(I0), // freeze mblock
+                       make_pass_through_transform(
+                           Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
+                                                                     // shuffle
                        make_unmerge_transform(make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                        make_freeze_transform(I0), // freeze nblock
-                       make_pass_through_transform(Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
+                       make_pass_through_transform(
+                           Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
+                                                                     // shuffle
                        make_unmerge_transform(make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
             make_tuple(Sequence<0>{},
...
...
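The transforms in this hunk only re-index the same data: freeze pins a dimension to a fixed coordinate (here the block index, already chosen), pass-through keeps a dimension unchanged, and unmerge splits one linear dimension into several factors. A minimal illustration of the unmerge index arithmetic (plain C++, not CK code) is shown below.

#include <array>

// Split a flat index m (0 <= m < M1*M2*M3*M4) into coordinates (m1, m2, m3, m4),
// the same kind of arithmetic that make_unmerge_transform(make_tuple(M1, M2, M3, M4))
// encodes, with the last factor varying fastest.
std::array<int, 4> unmerge(int m, int M2, int M3, int M4)
{
    std::array<int, 4> idx{};
    idx[3] = m % M4; m /= M4;
    idx[2] = m % M3; m /= M3;
    idx[1] = m % M2; m /= M2;
    idx[0] = m; // m1
    return idx;
}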
@@ -709,8 +650,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
                                          n_thread_data_on_block_idx[I2]),
                          ck::tensor_operation::element_wise::PassThrough{}};

-        auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r3<
-            BlockSize,                  // index_t BlockSize,
+        auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
+            ThisThreadBlock,            // ThreadGroup
             CElementwiseOperation,      // ElementwiseOperation,
             CGlobalMemoryDataOperation, // DstInMemOp,
             Sequence<1,
...
...
@@ -851,4 +792,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 };

 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dlops.hpp → include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp
View file @ a3b4c5cb
-#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP
-#define CK_THREADWISE_CONTRACTION_DLOPS_HPP
+#pragma once

 #include "common_header.hpp"
 #include "math.hpp"
...
...
@@ -25,9 +23,9 @@ template <typename FloatA,
                                  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                                  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                              bool>::type = false>
-struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
+struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1
 {
-    __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
+    __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
     {
         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                           BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...
...
@@ -124,9 +122,9 @@ template <typename FloatA,
                                  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                                  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                              bool>::type = false>
-struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
+struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
 {
-    __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
+    __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
     {
         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                           BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...
...
@@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
 };

 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @ a3b4c5cb
...
...
@@ -51,7 +51,7 @@ template <typename SrcData,
           typename DstData,
           typename SrcDesc,
           typename DstDesc,
-          typename DstElementwiseOperation,
+          typename ElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,
...
...
@@ -70,12 +70,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

-    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
                                                             const Index& dst_slice_origin_idx,
-                                                            const DstElementwiseOperation& dst_element_op)
+                                                            const ElementwiseOperation& element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
-          dst_element_op_{dst_element_op}
+          element_op_{element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
...
...
@@ -136,13 +135,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

-                SrcData dst_v;
+                SrcData v;

                 // apply element-wise operation
-                dst_element_op_(dst_v, src_buf[Number<src_offset>{}]);
+                element_op_(v, src_buf[Number<src_offset>{}]);

                 // apply type convert
-                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(dst_v);
+                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
             });

             const bool is_dst_valid =
...
...
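As a concrete, non-CK illustration of the pattern this hunk standardises — run the elementwise operation in the source precision, then convert to the destination type — consider the following sketch (the Scale functor is hypothetical):

#include <cstdint>

// Hypothetical elementwise op in the style of ck::tensor_operation::element_wise ops:
// it writes its result through the first argument, like element_op_(v, src) above.
struct Scale
{
    float scale_;
    void operator()(float& y, const float& x) const { y = scale_ * x; }
};

// Apply the op in float, then narrow to the destination type (here int8_t),
// mirroring element_op_(v, src_buf[...]) followed by type_convert<DstData>(v).
int8_t transfer_one(float src, const Scale& op)
{
    float v;                       // same role as "SrcData v" in the transfer above
    op(v, src);                    // elementwise operation in source precision
    return static_cast<int8_t>(v); // type conversion to DstData
}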
@@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     private:
     DstCoord dst_coord_;
-    const DstElementwiseOperation dst_element_op_;
+    const ElementwiseOperation element_op_;
 }; // namespace ThreadwiseTensorSliceTransfer_v1r3

 // Assume:
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
View file @ a3b4c5cb
-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
+#pragma once

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
...
...
@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
 };

 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
View file @ a3b4c5cb
...
...
@@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
             // apply pointwise operation
             static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-                element_op_(dst_vector_container.template AsType<DstData>()(i),
-                            src_vector_container.template AsType<SrcData>()[i]);
+                SrcData v;
+
+                // apply element-wise operation
+                element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
+
+                // apply type convert
+                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
             });

             const bool is_dst_valid =
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @ a3b4c5cb
...
...
@@ -25,6 +25,7 @@ enum struct MfmaInstr
     mfma_f32_16x16x8bf16,
     mfma_i32_32x32x8i8,
     mfma_i32_16x16x16i8,
+    mfma_f64_16x16x4f64
 };

 template <MfmaInstr instr>
...
...
@@ -383,12 +384,40 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
+{
+    static constexpr index_t group_size          = 1;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 4; // group_size * num_groups_per_blk;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4; // wave_size / num_threads_per_blk;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 1;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f64_16x16x4f64<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
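The constants in this new fp64 specialization are mutually consistent with the relations stated in its own comments; a tiny standalone check (illustrative only, not part of the commit) that encodes them:

namespace {
constexpr int group_size          = 1;
constexpr int num_groups_per_blk  = 4;
constexpr int num_regs_per_blk    = 4;
constexpr int num_threads_per_blk = 16;
constexpr int wave_size           = 64;
constexpr int num_input_blks      = 4;
constexpr int m_per_blk           = 16;
constexpr int n_per_blk           = 16;

static_assert(num_regs_per_blk == group_size * num_groups_per_blk,
              "register count per block is group_size * num_groups_per_blk");
static_assert(num_input_blks == wave_size / num_threads_per_blk,
              "with is_k_reduction, the wave is split into 4 input blocks along K");
static_assert(m_per_blk * n_per_blk == wave_size * num_regs_per_blk,
              "the 16x16 fp64 accumulator tile is spread over 64 lanes x 4 values per lane");
} // namespace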
 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector
 {
     template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
     static constexpr auto GetMfma();

+    template <>
+    static constexpr auto GetMfma<double, 16, 16>()
+    {
+        return MfmaInstr::mfma_f64_16x16x4f64;
+    }
+
     template <>
     static constexpr auto GetMfma<float, 64, 64>()
     {
...
...
@@ -661,9 +690,10 @@ struct XdlopsGemm
     template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
-        static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
-                          is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value,
-                      "base base_type must be float, half, bfloat16, and int8_t!");
+        static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
+                          is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
+                          is_same<base_type, int8_t>::value,
+                      "base base_type must be double, float, half, bfloat16, and int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
             mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @ a3b4c5cb
...
...
@@ -258,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
     index_t soffset,
     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
 // buffer atomic-add fp32

+__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64(
+    double vdata,
+    int32x4_t rsrc, // dst_wave_buffer_resource
+    int voffset,    // dst_thread_addr_offset
+    int soffset,    // dst_wave_addr_offset
+    int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
+
 template <typename T, index_t N>
 __device__ typename vector_type<T, N>::type amd_buffer_load_impl(
     int32x4_t src_wave_buffer_resource,
     index_t src_thread_addr_offset,
...
...
@@ -915,6 +923,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
     }
 }

+template <typename T, index_t N>
+__device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
+                                           int32x4_t dst_wave_buffer_resource,
+                                           index_t dst_thread_addr_offset,
+                                           index_t dst_wave_addr_offset)
+{
+    static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
+                  "wrong! not implemented");
+
+    if constexpr(is_same<T, double>::value)
+    {
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data,
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset,
+                                                   0);
+        }
+        else if constexpr(N == 2)
+        {
+            vector_type<double, 2> tmp{src_thread_data};
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset,
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + sizeof(double),
+                                                   0);
+        }
+        else if constexpr(N == 4)
+        {
+            vector_type<double, 4> tmp{src_thread_data};
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset,
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + sizeof(double),
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<2>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + 2 * sizeof(double),
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<3>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + 3 * sizeof(double),
+                                                   0);
+        }
+    }
+}
+
 // buffer_load requires:
 // 1) p_src_wave must point to global memory space
 // 2) p_src_wave must be a wavewise pointer.
...
...
@@ -1046,4 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
 #endif
 }

+// buffer_atomic_max requires:
+// 1) p_dst_wave must point to global memory
+// 2) p_dst_wave must be a wavewise pointer.
+// It is user's responsibility to make sure that is true.
+template <typename T, index_t N>
+__device__ void amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
+                                      T* p_dst_wave,
+                                      const index_t dst_thread_element_offset,
+                                      const bool dst_thread_element_valid,
+                                      const index_t dst_element_space_size)
+{
+    const int32x4_t dst_wave_buffer_resource =
+        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
+
+    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+
+    using vector_t = typename vector_type_maker<T, N>::type::type;
+    using scalar_t = typename scalar_type<vector_t>::type;
+
+    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
+
+#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+
+    amd_buffer_atomic_max_impl<scalar_t, vector_size>(
+        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+#else
+    if(dst_thread_element_valid)
+    {
+        amd_buffer_atomic_max_impl<scalar_t, vector_size>(
+            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+    }
+#endif
+}
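Under CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK the new wrapper avoids a branch for invalid lanes: pushing the per-thread byte offset by 0x7fffffff moves the access past the buffer's declared range, and the hardware bounds check of buffer instructions then drops the atomic. A hedged usage sketch of the wrapper (illustrative device helper, not CK code; the include path is assumed):

#include "ck/utility/amd_buffer_addressing.hpp" // include path assumed for this sketch

// Illustrative device helper: fold one value per thread into a per-column running
// maximum. Invalid lanes are masked through the "valid" flag, so no divergent
// branch is needed around the atomic.
__device__ void column_max_update(double v,
                                  double* p_col_max,       // wave-uniform base pointer
                                  ck::index_t column,      // this thread's column index
                                  ck::index_t num_columns) // buffer length in elements
{
    const bool valid = column < num_columns;

    // N = 1: one double per thread; the wrapper handles the out-of-range lanes.
    ck::amd_buffer_atomic_max<double, 1>(v, p_col_max, column, valid, num_columns);
}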
} // namespace ck
include/ck/utility/amd_xdlops.hpp
View file @ a3b4c5cb
...
...
@@ -266,8 +266,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x16_t>()(Number<0>{}) =
-            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
-                                                bit_cast<int>(reg_b),
+            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
+                                                bit_cast<int32_t>(reg_b),
                                                 reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                 0,
                                                 0,
...
...
@@ -285,8 +285,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x4_t>()(Number<0>{}) =
-            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
-                                                 bit_cast<int>(reg_b),
+            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
+                                                 bit_cast<int32_t>(reg_b),
                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                  0,
                                                  0,
...
...
@@ -294,5 +294,24 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f64_16x16x4f64;
+
+template <>
+struct intrin_mfma_f64_16x16x4f64<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
+    {
+#ifdef __gfx90a__
+        reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
+            reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
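For context, the wrapper above forwards directly to the gfx90a compiler builtin. A minimal smoke-test kernel calling that builtin outside of CK (illustrative only; assumes a single 64-lane wavefront, and the dbl4_t typedef stands in for CK's double4_t):

#include <hip/hip_runtime.h>

typedef double dbl4_t __attribute__((ext_vector_type(4))); // local stand-in for double4_t

// Illustrative gfx90a-only smoke test: a and b hold 64 doubles each,
// c receives 4 accumulator values per lane (one 16x16 fp64 tile per wave).
__global__ void mfma_f64_smoke(const double* a, const double* b, double* c)
{
#if defined(__gfx90a__)
    dbl4_t acc = {0.0, 0.0, 0.0, 0.0};
    // One MFMA: per-lane scalar A and B operands, four fp64 accumulators per lane,
    // with K split across the wave's 4 input blocks (K-reduction).
    acc = __builtin_amdgcn_mfma_f64_16x16x4f64(a[threadIdx.x], b[threadIdx.x], acc, 0, 0, 0);
    for(int i = 0; i < 4; ++i)
        c[threadIdx.x * 4 + i] = acc[i];
#else
    (void)a; (void)b; (void)c;
#endif
}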
} // namespace ck
#endif