gaoqiong / composable_kernel, commit 18707866
Authored Apr 10, 2022 by Chao Liu
Parent: ee33b1fa

    adding thread group

Showing 3 changed files with 206 additions and 81 deletions:

include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp   +40 -37
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp       +71 -21
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp   +95 -23
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp

@@ -27,7 +27,8 @@ template <typename ALayout,
           typename CElementwiseOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
-          index_t BlockSize,
+          index_t ABBlockTransferThreadGroupSize,
+          index_t BlockGemmThreadGroupSize,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
@@ -346,7 +347,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
         BGridDesc_BK0_N_BK1,
         CGridDesc_M_N,
         NumGemmKPrefetchStage,
-        BlockSize,
+        ABBlockTransferThreadGroupSize,
+        BlockGemmThreadGroupSize,
         MPerBlock,
         NPerBlock,
         KPerBlock,
@@ -487,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
             {
                 launch_kernel(kernel,
                               dim3(grid_size),
-                              dim3(BlockSize),
+                              dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
                               0,
                               arg.p_a_grid_,
                               arg.p_b_grid_,
@@ -502,22 +504,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
             }
             else
             {
                 ave_time = launch_and_time_kernel(kernel,
                                                   nrepeat,
                                                   dim3(grid_size),
-                                                  dim3(BlockSize),
+                                                  dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
                                                   0,
                                                   arg.p_a_grid_,
                                                   arg.p_b_grid_,
                                                   arg.p_c_grid_,
                                                   arg.a_element_op_,
                                                   arg.b_element_op_,
                                                   arg.c_element_op_,
                                                   arg.a_grid_desc_ak0_m_ak1_,
                                                   arg.b_grid_desc_bk0_n_bk1_,
                                                   arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                                   arg.block_2_ctile_map_);
             }
         }
         else
@@ -539,7 +541,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
             {
                 launch_kernel(kernel,
                               dim3(grid_size),
-                              dim3(BlockSize),
+                              dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
                               0,
                               arg.p_a_grid_,
                               arg.p_b_grid_,
@@ -554,22 +556,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
             }
             else
             {
                 ave_time = launch_and_time_kernel(kernel,
                                                   nrepeat,
                                                   dim3(grid_size),
-                                                  dim3(BlockSize),
+                                                  dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
                                                   0,
                                                   arg.p_a_grid_,
                                                   arg.p_b_grid_,
                                                   arg.p_c_grid_,
                                                   arg.a_element_op_,
                                                   arg.b_element_op_,
                                                   arg.c_element_op_,
                                                   arg.a_grid_desc_ak0_m_ak1_,
                                                   arg.b_grid_desc_bk0_n_bk1_,
                                                   arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                                   arg.block_2_ctile_map_);
             }
         }
@@ -673,7 +675,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
         // clang-format off
         str << "DeviceGemm_Xdl_CShuffle_v2"
             << "<"
-            << BlockSize << ", "
+            << ABBlockTransferThreadGroupSize << ", "
+            << BlockGemmThreadGroupSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
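The launch-side effect of this change is that the block dimension is no longer a single BlockSize template parameter but the sum of the two thread-group sizes. A minimal host-side sketch, assuming hypothetical group sizes of 64 and 192 (in the commit these are template parameters of the device op):

    #include <hip/hip_runtime.h>

    // Hypothetical sizes for illustration only.
    constexpr unsigned ABBlockTransferThreadGroupSize = 64;
    constexpr unsigned BlockGemmThreadGroupSize       = 192;

    dim3 MakeBlockDim()
    {
        // before: dim3(BlockSize)
        // after:  one thread block holds both groups back to back
        return dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize); // 256 threads
    }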
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp

@@ -4,7 +4,9 @@
 namespace ck {

-template <typename AGridDesc,
+template <typename ABBlockTransferThreadGroup,
+          typename BlockGemmThreadGroup,
+          typename AGridDesc,
           typename ABlockDesc,
           typename ABlockTransfer,
           typename AGridBuffer,
@@ -23,7 +25,9 @@ template <typename AGridDesc,
 struct GridwiseGemmPipeline_v2;

 // 1-stage prefetch
-template <typename AGridDesc,
+template <typename ABBlockTransferThreadGroup,
+          typename BlockGemmThreadGroup,
+          typename AGridDesc,
           typename ABlockDesc,
           typename ABlockTransfer,
           typename AGridBuffer,
@@ -38,7 +42,9 @@ template <typename AGridDesc,
           typename BlockwiseGemm,
           typename CThreadBuffer,
           bool HasMainLoop>
-struct GridwiseGemmPipeline_v2<AGridDesc,
+struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
+                               BlockGemmThreadGroup,
+                               AGridDesc,
                                ABlockDesc,
                                ABlockTransfer,
                                AGridBuffer,
@@ -58,19 +64,24 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

-    static __device__ void RunProducer(const AGridDesc& a_grid_desc,
-                                       const ABlockDesc& a_block_desc,
-                                       ABlockTransfer& a_blockwise_copy,
-                                       const AGridBuffer& a_grid_buf,
-                                       ABlockBuffer& a_block_buf,
-                                       const ABlockTransferStep& a_block_copy_step,
-                                       const BGridDesc& b_grid_desc,
-                                       const BBlockDesc& b_block_desc,
-                                       BBlockTransfer& b_blockwise_copy,
-                                       const BGridBuffer& b_grid_buf,
-                                       BBlockBuffer& b_block_buf,
-                                       const BBlockTransferStep& b_block_copy_step,
-                                       index_t num_loop)
+    __device__ constexpr GridwiseGemmPipeline_v2()
+    {
+        // TODO static assert
+    }
+
+    static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
+                                                      const ABlockDesc& a_block_desc,
+                                                      ABlockTransfer& a_blockwise_copy,
+                                                      const AGridBuffer& a_grid_buf,
+                                                      ABlockBuffer& a_block_buf,
+                                                      const ABlockTransferStep& a_block_copy_step,
+                                                      const BGridDesc& b_grid_desc,
+                                                      const BBlockDesc& b_block_desc,
+                                                      BBlockTransfer& b_blockwise_copy,
+                                                      const BGridBuffer& b_grid_buf,
+                                                      BBlockBuffer& b_block_buf,
+                                                      const BBlockTransferStep& b_block_copy_step,
+                                                      index_t num_loop)
     {
         // global read 0
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
@@ -140,11 +151,11 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
         }
     }

-    static __device__ void RunConsumer(ABlockBuffer& a_block_buf,
+    static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
                                        BBlockBuffer& b_block_buf,
                                        const BlockwiseGemm& blockwise_gemm,
                                        CThreadBuffer& c_thread_buf,
                                        index_t num_loop)
     {
         // Initialize C
         c_thread_buf.Clear();
@@ -193,6 +204,45 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
         }
     }
+
+    static __device__ void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        if(ABBlockTransferThreadGroup::IsBelong())
+        {
+            gridwise_gemm_pipeline.RunABBlockTransferPipeline(a_grid_desc_ak0_m_ak1,
+                                                              a_block_desc_ak0_m_ak1,
+                                                              a_blockwise_copy,
+                                                              a_grid_buf,
+                                                              a_block_buf,
+                                                              a_block_slice_copy_step,
+                                                              b_grid_desc_bk0_n_bk1,
+                                                              b_block_desc_bk0_n_bk1,
+                                                              b_blockwise_copy,
+                                                              b_grid_buf,
+                                                              b_block_buf,
+                                                              b_block_slice_copy_step,
+                                                              num_loop);
+        }
+        else if(BlockGemmThreadGroup::IsBelong())
+        {
+            gridwise_gemm_pipeline.RunBlockGemmPipeline(a_block_buf,
+                                                        b_block_buf,
+                                                        blockwise_gemm,
+                                                        c_thread_buf,
+                                                        num_loop);
+        }
+    }
 };

 } // namespace ck
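The new Run entry point dispatches on thread-group membership: threads belonging to ABBlockTransferThreadGroup run the global-to-LDS copy pipeline, while threads belonging to BlockGemmThreadGroup run the blockwise GEMM pipeline. A self-contained sketch of that producer/consumer split, assuming hypothetical group sizes, a trivial sum in place of the GEMM, and a plain __shared__ array in place of the LDS block buffers:

    #include <hip/hip_runtime.h>

    constexpr int CopyGroupSize = 64;   // hypothetical stand-in for ABBlockTransferThreadGroupSize
    constexpr int GemmGroupSize = 192;  // hypothetical stand-in for BlockGemmThreadGroupSize
    constexpr int TileSize      = 256;

    // Launch with blockDim.x == CopyGroupSize + GemmGroupSize.
    __global__ void split_group_kernel(const float* __restrict__ in,
                                       float* __restrict__ out,
                                       int num_tiles)
    {
        __shared__ float tile[TileSize];

        const int  tid           = threadIdx.x;
        const bool is_copy_group = tid < CopyGroupSize; // mirrors ABBlockTransferThreadGroup::IsBelong()
        const int  group_tid     = is_copy_group ? tid : tid - CopyGroupSize;

        float acc = 0.f;

        for(int t = 0; t < num_tiles; ++t)
        {
            if(is_copy_group)
            {
                // "producer" group: stage one tile from global memory into shared memory
                for(int i = group_tid; i < TileSize; i += CopyGroupSize)
                    tile[i] = in[t * TileSize + i];
            }

            __syncthreads(); // every thread in the block reaches the barrier

            if(!is_copy_group)
            {
                // "consumer" group: work on the staged tile (a sum stands in for the GEMM)
                for(int i = group_tid; i < TileSize; i += GemmGroupSize)
                    acc += tile[i];
            }

            __syncthreads(); // tile fully consumed before the next iteration overwrites it
        }

        if(!is_copy_group)
            out[blockIdx.x * GemmGroupSize + group_tid] = acc; // out holds gridDim.x * GemmGroupSize values
    }

This sketch serializes the two roles at block-wide barriers; it is only meant to illustrate how IsBelong()-style membership routes the threads of one block into different pipelines.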
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp

@@ -67,7 +67,8 @@ template <typename FloatAB,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
           index_t NumGemmKPrefetchStage,
-          index_t BlockSize,
+          index_t ABBlockTransferThreadGroupSize,
+          index_t BlockGemmThreadGroupSize,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
@@ -114,6 +115,50 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};

+    using ThisThreadBlock =
+        AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
+
+#if 1
+    using ABBlockTransferThreadGroup       = ThisThreadBlock;
+    using BlockGemmThreadGroup             = ThisThreadBlock;
+    using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
+#else
+    struct ABBlockTransferThreadGroup
+    {
+        __device__ static constexpr index_t GetNumOfThread()
+        {
+            return ABBlockTransferThreadGroupSize;
+        }
+
+        __device__ static constexpr bool IsBelong()
+        {
+            return get_thread_local_1d_id() < ABBlockTransferThreadGroupSize;
+        }
+
+        __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
+    };
+
+    struct BlockGemmThreadGroup
+    {
+        __device__ static constexpr index_t GetNumOfThread()
+        {
+            return ABBlockTransferThreadGroupSize;
+        }
+
+        __device__ static constexpr bool IsBelong()
+        {
+            return get_thread_local_1d_id() >= ABBlockTransferThreadGroupSize;
+        }
+
+        __device__ static index_t GetThreadId()
+        {
+            return get_thread_local_1d_id() - ABBlockTransferThreadGroupSize;
+        }
+    };
+
+    using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
+#endif
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -345,11 +390,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         // B matrix in LDS memory, dst of blockwise copy
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

-        using ThisThreadBlock = AnyThreadBlock<BlockSize>;
-
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+            ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
@@ -380,7 +423,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+            ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
@@ -420,7 +463,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockGemmThreadGroup,
                                                                 FloatAB,
                                                                 FloatGemmAcc,
                                                                 decltype(a_block_desc_ak0_m_ak1),
@@ -447,6 +490,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
+            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
+            KPerBlock);
+
+#if 1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
             GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
@@ -465,10 +513,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
                                     remove_cvref_t<decltype(c_thread_buf)>,
                                     NumGemmKPrefetchStage,
                                     HasMainK0BlockLoop>{};
-
-        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
-            KPerBlock);
+#else
+        // gridwise GEMM pipeline
+        const auto gridwise_gemm_pipeline =
+            GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
+                                    BlockGemmThreadGroup,
+                                    remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
+                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
+                                    remove_cvref_t<decltype(a_blockwise_copy)>,
+                                    remove_cvref_t<decltype(a_grid_buf)>,
+                                    remove_cvref_t<decltype(a_block_buf)>,
+                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
+                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
+                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
+                                    remove_cvref_t<decltype(b_blockwise_copy)>,
+                                    remove_cvref_t<decltype(b_grid_buf)>,
+                                    remove_cvref_t<decltype(b_block_buf)>,
+                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
+                                    remove_cvref_t<decltype(blockwise_gemm)>,
+                                    remove_cvref_t<decltype(c_thread_buf)>,
+                                    NumGemmKPrefetchStage,
+                                    HasMainK0BlockLoop>{};
+#endif

         gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
                                    a_block_desc_ak0_m_ak1,
@@ -601,7 +667,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         // shuffle: blockwise copy C from LDS to global
         auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
-            ThisThreadBlock,            // index_t BlockSize,
+            ThisThreadBlock,            // ThreadGroup
             CElementwiseOperation,      // ElementwiseOperation,
             CGlobalMemoryDataOperation, // DstInMemOp,
             Sequence<1,
@@ -655,22 +721,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
                 // make sure it's safe to write to LDS
                 block_sync_lds();

-                // each thread write its data from VGPR to LDS
-                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                              c_thread_buf,
-                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              c_shuffle_block_buf);
+                if(BlockGemmThreadGroup::IsBelong())
+                {
+                    // thread write its data from VGPR to LDS
+                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                                  c_thread_buf,
+                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  c_shuffle_block_buf);
+                }

                 // make sure it's safe to read from LDS
                 block_sync_lds();

-                // each block copy its data from LDS to global
-                c_shuffle_block_copy_lds_to_global.Run(
-                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                    c_shuffle_block_buf,
-                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                    c_grid_buf);
+                if(CShuffleBlockTransferThreadGroup::IsBelong())
+                {
+                    // block copy its data from LDS to global
+                    c_shuffle_block_copy_lds_to_global.Run(
+                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                        c_shuffle_block_buf,
+                        c_grid_desc_mblock_mperblock_nblock_nperblock,
+                        c_grid_buf);
+                }

                 if constexpr(access_id < num_access - 1)
                 {
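For reference, the disabled #else branch above defines the two groups as disjoint slices of the flat thread id. A standalone restatement of that interface, with hypothetical sizes and a local stand-in for get_thread_local_1d_id():

    #include <hip/hip_runtime.h>

    constexpr int ABBlockTransferThreadGroupSize = 64;  // hypothetical
    constexpr int BlockGemmThreadGroupSize       = 192; // hypothetical

    __device__ inline int get_thread_local_1d_id() { return static_cast<int>(threadIdx.x); }

    // Threads [0, ABBlockTransferThreadGroupSize) belong to the transfer group ...
    struct ABBlockTransferThreadGroup
    {
        __device__ static constexpr int GetNumOfThread() { return ABBlockTransferThreadGroupSize; }
        __device__ static bool IsBelong()
        {
            return get_thread_local_1d_id() < ABBlockTransferThreadGroupSize;
        }
        __device__ static int GetThreadId() { return get_thread_local_1d_id(); }
    };

    // ... and the remaining threads belong to the GEMM group, with a group-local id starting at 0.
    struct BlockGemmThreadGroup
    {
        __device__ static constexpr int GetNumOfThread() { return BlockGemmThreadGroupSize; }
        __device__ static bool IsBelong()
        {
            return get_thread_local_1d_id() >= ABBlockTransferThreadGroupSize;
        }
        __device__ static int GetThreadId()
        {
            return get_thread_local_1d_id() - ABBlockTransferThreadGroupSize;
        }
    };

The enabled #if 1 branch in the commit instead aliases all three groups to ThisThreadBlock; the partitioned variant is kept behind the preprocessor switch.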