gaoqiong / composable_kernel · Commits · 5dce9c9d

Commit 5dce9c9d, authored Sep 08, 2022 by wangshaojie6
Parent: f0d63f25

    make device/grid level code

Showing 4 changed files with 107 additions and 129 deletions (+107, -129):
    example/42_splitK_gemm_bias/run_splitK_gemm_bias_example.inc                                   +9  -23
    example/42_splitK_gemm_bias/splitK_gemm_bias_xdl_fp16.cpp                                      +2  -2
    include/ck/tensor_operation/gpu/device/device_contraction_splitK_multiple_d_xdl_cshuffle.hpp   +26 -73
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_splitk_multiple_d_xdl_cshuffle.hpp          +70 -31
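For context before the diffs: split-K GEMM partitions the reduction dimension K into KBatch slices, hands each slice to a different set of thread blocks, and combines the partial products in the output tensor E, here via atomic adds. A minimal sketch of the decomposition (illustrative serial code, not CK's implementation; atomic_add stands in for the device intrinsic):

    // Each kb iteration runs as a separate group of thread blocks in the
    // real kernel; the atomic add is what makes the combination safe.
    for(int kb = 0; kb < KBatch; ++kb)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float partial = 0.f;
                for(int k = kb * (K / KBatch); k < (kb + 1) * (K / KBatch); ++k)
                    partial += a[m][k] * b[k][n];
                atomic_add(&e[m][n], partial); // InMemoryDataOperationEnum::AtomicAdd
            }

The bias tensor D must enter this sum exactly once, which is the subtlety most of the grid-level changes below deal with.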
example/42_splitK_gemm_bias/run_splitK_gemm_bias_example.inc
@@ -56,28 +56,14 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
 
-    auto f_tensor_length_stride_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return {std::vector<std::size_t>({row, col}),
-                        std::vector<std::size_t>({stride, 1})};
-            }
-            else
-            {
-                return {std::vector<std::size_t>({row, col}),
-                        std::vector<std::size_t>({1, stride})};
-            }
-        };
-
-    std::vector<ck::index_t> a_ms_ks_lengths =
-        f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[0];
-    std::vector<ck::index_t> a_ms_ks_strides =
-        f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[1];
-    std::vector<ck::index_t> b_ns_ks_lengths =
-        f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[0];
-    std::vector<ck::index_t> b_ns_ks_strides =
-        f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[1];
-    std::vector<ck::index_t> d_ms_ns_lengths =
-        f_tensor_length_stride_descriptor(M, N, 0, Row{})[0];
-    std::vector<ck::index_t> d_ms_ns_strides =
-        f_tensor_length_stride_descriptor(M, N, 0, Row{})[1];
-    std::vector<ck::index_t> e_ms_ns_lengths =
-        f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[0];
-    std::vector<ck::index_t> e_ms_ns_strides =
-        f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[1];
+    std::vector<ck::index_t> a_ms_ks_lengths = {M, K};
+    std::vector<ck::index_t> a_ms_ks_strides = {StrideA, 1};
+    std::vector<ck::index_t> b_ns_ks_lengths = {N, K};
+    std::vector<ck::index_t> b_ns_ks_strides = {StrideB, 1};
+    std::vector<ck::index_t> d_ms_ns_lengths = {M, N};
+    std::vector<ck::index_t> d_ms_ns_strides = {0, 1};
+    std::vector<ck::index_t> e_ms_ns_lengths = {M, N};
+    std::vector<ck::index_t> e_ms_ns_strides = {StrideE, 1};
 
     switch(config.init_method)
     {
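The deleted helper returned a braced-init-list from a lambda with a deduced return type, which C++ cannot deduce, so the direct initializations above are a compile fix as well as a simplification. If such a helper were still wanted, a sketch of a compiling variant (hypothetical, not part of this commit; assumes <array>, <vector>, and <type_traits> are included) could name the return type explicitly:

    // Hypothetical replacement: an explicit std::array return type makes
    // the braced return well-formed. RowMajor gets strides {stride, 1};
    // column-major layouts get {1, stride}.
    auto f_tensor_length_stride_descriptor =
        [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout)
        -> std::array<std::vector<ck::index_t>, 2> {
        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            return {{{row, col}, {stride, 1}}};
        else
            return {{{row, col}, {1, stride}}};
    };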
@@ -176,7 +162,7 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     Tensor<CDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
 
     auto ref_argument = ref_gemm.MakeArgument(
-        a_m_k, b_k_n, e_m_n_host_result, d_m_n, a_element_op, b_element_op, c_element_op);
+        a_m_k, b_k_n, e_m_n_host_result, d_m_n, a_element_op, b_element_op, cde_element_op);
 
     ref_invoker.Run(ref_argument);
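The renamed cde_element_op is the operation that folds the bias tensor D into the GEMM output on both the reference and device paths. For reference, a minimal sketch of what such a CDE functor looks like (illustrative only; the example actually uses the operation defined in CK's element_wise namespace):

    // Hypothetical bias-add functor in the shape CK expects of a CDE
    // elementwise operation: E = C + D, applied per element.
    struct AddBias
    {
        template <typename E, typename C, typename D>
        __host__ __device__ void operator()(E& e, const C& c, const D& d) const
        {
            e = c + d;
        }
    };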
example/42_splitK_gemm_bias/splitK_gemm_bias_xdl_fp16.cpp
@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
+#include "ck/tensor_operation/gpu/device/device_contraction_splitK_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -51,7 +51,7 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 
 // clang-format off
 using DeviceOpInstanceKKN = ck::tensor_operation::device::
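GemmSpecialization::MNKPadding pads all three problem dimensions up to whole block tiles, so the instance accepts arbitrary M, N, K at the cost of some wasted work in edge tiles. A sketch of the arithmetic (assuming MPerBlock, NPerBlock, KPerBlock are the instance's tile sizes):

    // Round each dimension up to a multiple of its per-block tile size.
    const ck::index_t m_padded = ((M + MPerBlock - 1) / MPerBlock) * MPerBlock;
    const ck::index_t n_padded = ((N + NPerBlock - 1) / NPerBlock) * NPerBlock;
    const ck::index_t k_padded = ((K + KPerBlock - 1) / KPerBlock) * KPerBlock;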
include/ck/tensor_operation/gpu/device/device_contraction_splitK_multiple_d_xdl_cshuffle.hpp
@@ -15,12 +15,13 @@
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_splitk_multiple_d_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 
 namespace ck {
 
-template <typename GridwiseGemm,
+template <typename GridwiseGemmAtomicAdd,
           typename FloatAB,
           typename FloatDsPointer,
           typename FloatE,
@@ -57,7 +58,7 @@ __global__ void
           const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared[GridwiseGemmAtomicAdd::GetSharedMemoryNumberOfByte()];
 
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
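The batched kernel divides the grid evenly over the G (batch) dimension, and each block derives its batch from the flat block id. A sketch of the pattern (assuming get_block_1d_id() is the flat block id, as used elsewhere in this file):

    // With batch_count batches and get_grid_size() blocks in total, block
    // bid works on batch bid / num_blocks_per_batch. readfirstlane keeps
    // the wave-uniform value in a scalar register.
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx =
        __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);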
@@ -80,7 +81,7 @@ __global__ void
     static_for<0, NumDTensor, 1>{}(
         [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+    GridwiseGemmAtomicAdd::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_ds_grid_grp,
                                                   p_e_grid + e_batch_offset,
@@ -538,56 +539,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         EGridDesc_G_M_N e_grid_desc_g_m_n_;
     };
 
-    // GridwiseGemm
-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
-        ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
-        CShuffleDataType,
-        DsDataType,
-        EDataType,
-        AElementwiseOperation,
-        BElementwiseOperation,
-        CDEElementwiseOperation,
-        InMemoryDataOperationEnum::Set,
-        AGridDesc_M_K,
-        BGridDesc_N_K,
-        DsGridDesc_M_N,
-        EGridDesc_M_N,
-        NumGemmKPrefetchStage,
-        BlockSize,
-        MPerBlock,
-        NPerBlock,
-        KPerBlock,
-        AK1,
-        BK1,
-        MPerXDL,
-        NPerXDL,
-        MXdlPerWave,
-        NXdlPerWave,
-        ABlockTransferThreadClusterLengths_AK0_M_AK1,
-        ABlockTransferThreadClusterArrangeOrder,
-        ABlockTransferSrcAccessOrder,
-        ABlockTransferSrcVectorDim,
-        ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_AK1,
-        false,
-        ABlockLdsExtraM,
-        BBlockTransferThreadClusterLengths_BK0_N_BK1,
-        BBlockTransferThreadClusterArrangeOrder,
-        BBlockTransferSrcAccessOrder,
-        BBlockTransferSrcVectorDim,
-        BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_BK1,
-        false,
-        BBlockLdsExtraN,
-        CShuffleMXdlPerWavePerShuffle,
-        CShuffleNXdlPerWavePerShuffle,
-        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CDEBlockTransferScalarPerVector_NPerBlock,
-        LoopSched>;
-
-    // GridwiseGemm
-    using GridwiseGemmAtomicAdd = GridwiseGemmMultipleD_xdl_cshuffle<
+    // GridwiseGemmAtomicAdd
+    using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CShuffleDataType,
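The removed instantiation wrote E with InMemoryDataOperationEnum::Set; the split-K gridwise GEMM instead relies on AtomicAdd so the KBatch partial results can accumulate into the same E tile. Rough semantics of the two operations (an illustrative runtime sketch; CK actually selects the operation at compile time through template parameters):

    // Set overwrites the destination; AtomicAdd accumulates into it.
    template <typename T> // float-like T with a device atomicAdd overload
    __device__ void store(InMemoryDataOperationEnum op, T* dst, T src)
    {
        if(op == InMemoryDataOperationEnum::Set)
            *dst = src;
        else // InMemoryDataOperationEnum::AtomicAdd
            atomicAdd(dst, src);
    }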
@@ -635,11 +588,11 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         LoopSched>;
 
     using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+        GridwiseGemmAtomicAdd::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
     using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+        GridwiseGemmAtomicAdd::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
 
-    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+    using Block2ETileMap = typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap;
 
     // Argument
     struct Argument : public BaseArgument
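DefaultBlock2ETileMap is now the split-K map (BlockToCTileMap_KSplit_M00_N0_M01Adapt, as the grid-level file below shows), which prepends the K-batch index to the (m0, n0) tile coordinate. A sketch of the decomposition such a map performs (hypothetical helper; the real map also applies an M01 swizzle for cache locality, omitted here):

    struct KSplitTileIdx { ck::index_t kbatch, m0, n0; };

    // Split a flat block id into (k_batch, m_tile, n_tile), given
    // blocks_per_kbatch = grid_size / KBatch and n0_tiles tiles along N.
    KSplitTileIdx decompose(ck::index_t block_id,
                            ck::index_t blocks_per_kbatch,
                            ck::index_t n0_tiles)
    {
        const ck::index_t kbatch   = block_id / blocks_per_kbatch;
        const ck::index_t local_id = block_id % blocks_per_kbatch;
        return {kbatch, local_id / n0_tiles, local_id % n0_tiles};
    }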
@@ -676,12 +629,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
               e_grid_desc_g_m_n_{DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths,
                                                                      e_gs_ms_ns_strides)},
               a_grid_desc_ak0_m_ak1_{
-                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+                  GridwiseGemmAtomicAdd::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
               b_grid_desc_bk0_n_bk1_{
-                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+                  GridwiseGemmAtomicAdd::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
               ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
               e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, KBatch)},
+              block_2_etile_map_{
+                  GridwiseGemmAtomicAdd::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, 1, 1, KBatch)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op},
@@ -711,18 +664,18 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
             });
 
             // populate desc for Ds/E
-            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+            if(GridwiseGemmAtomicAdd::CheckValidity(a_grid_desc_m_k_,
                                            b_grid_desc_n_k_,
                                            ds_grid_desc_m_n_,
                                            e_grid_desc_m_n_,
                                            block_2_etile_map_))
             {
                 e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    GridwiseGemmAtomicAdd::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         e_grid_desc_m_n_);
 
                 ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    GridwiseGemmAtomicAdd::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         ds_grid_desc_m_n_);
             }
@@ -753,7 +706,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         // pointers
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        typename GridwiseGemm::DsGridPointer p_ds_grid_;
+        typename GridwiseGemmAtomicAdd::DsGridPointer p_ds_grid_;
         EDataType* p_e_grid_;
 
         // tensor descriptors for problem definition
@@ -768,9 +721,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        typename GridwiseGemmAtomicAdd::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             ds_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        typename GridwiseGemmAtomicAdd::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             e_grid_desc_mblock_mperblock_nblock_nperblock_;
 
         // block-to-e-tile map
@@ -804,7 +757,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+            if(!GridwiseGemmAtomicAdd::CheckValidity(arg.a_grid_desc_m_k_,
                                             arg.b_grid_desc_n_k_,
                                             arg.ds_grid_desc_m_n_,
                                             arg.e_grid_desc_m_n_,
@@ -826,19 +779,19 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
                 constexpr bool has_main_loop = has_main_k_block_loop.value;
 
                 const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
-                    GridwiseGemm,
+                    GridwiseGemmAtomicAdd,
                     ADataType, // TODO: distinguish A/B datatype
-                    typename GridwiseGemm::DsGridPointer,
+                    typename GridwiseGemmAtomicAdd::DsGridPointer,
                     EDataType,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CDEElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
-                    typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemmAtomicAdd::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemmAtomicAdd::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     ComputePtrOffsetOfStridedBatch,
-                    typename GridwiseGemm::DefaultBlock2ETileMap,
+                    typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
                     has_main_loop>;
 
                 return launch_and_time_kernel(stream_config,
@@ -862,7 +815,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
                         arg.block_2_etile_map_);
             };
 
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            if(GridwiseGemmAtomicAdd::CalculateHasMainKBlockLoop(K))
             {
                 return launch_kernel(integral_constant<bool, true>{});
             }
@@ -887,7 +840,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
             return false;
         }
 
-        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+        if(!GridwiseGemmAtomicAdd::CheckValidity(arg.a_grid_desc_m_k_,
                                         arg.b_grid_desc_n_k_,
                                         arg.ds_grid_desc_m_n_,
                                         arg.e_grid_desc_m_n_,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_splitk_multiple_d_xdl_cshuffle.hpp
@@ -11,6 +11,7 @@
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -231,17 +232,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             Number<NumDTensor>{});
     }
 
-    // return block_id to E matrix tile idx (m0, n0) mapping
-    // __host__ __device__ static constexpr auto
-    // MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
-    // {
-    //     return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
-    //         e_grid_desc_m_n);
-    // }
-
     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap(
         const EGridDesc_M_N& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch = 1)
     {
         return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
             c_m_n_grid_desc, 8, KBatch);
@@ -263,6 +259,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = a_grid_desc_m_k.GetLength(I1);
 
+        if(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
+        {
+            return false;
+        }
+
         // check consistency of desc
         if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
         {
@@ -332,7 +333,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
         MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
 
     using DefaultBlock2ETileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, KBatch))>;
+        remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1, 1, 1))>;
 
     using DsGridPointer = decltype(MakeDsGridPointer());
@@ -378,19 +379,22 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
 
-        if(!block_2_etile_map.ValidCTileIndex(block_work_idx,
+        if(!block_2_etile_map.ValidCTileIndex(make_tuple(block_work_idx[I1], block_work_idx[I2]),
               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
         {
             return;
         }
 
+        // k batch id
+        const index_t k_batch_id = block_work_idx[I0];
+
         // HACK: this forces m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
 
         const index_t n_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
 
         // lds max alignment
         constexpr auto max_lds_align = math::lcm(AK1, BK1);
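With the split-K map, CalculateBottomIndex returns a three-component index: [I0] is the K batch, [I1] the M tile, [I2] the N tile, which is why every subscript above shifts by one. A worked example with illustrative numbers:

    // KBatch = 4 over an 8 x 16 grid of (m0, n0) tiles: 4 * 8 * 16 = 512
    // blocks, blocks_per_kbatch = 128. For flat block id 300:
    //   k_batch_id = 300 / 128 = 2
    //   local_id   = 300 % 128 = 44
    //   m0         = 44 / 16   = 2
    //   n0         = 44 % 16   = 12
    // so block 300 computes tile (2, 12) of K-slice 2.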
@@ -426,7 +430,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
                                                 true,
                                                 NumGemmKPrefetchStage>(
                 a_grid_desc_ak0_m_ak1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                make_multi_index(k_batch_id, m_block_data_idx_on_grid, 0),
                 a_element_op,
                 a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
@@ -457,7 +461,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
                                                 true,
                                                 NumGemmKPrefetchStage>(
                 b_grid_desc_bk0_n_bk1,
-                make_multi_index(0, n_block_data_idx_on_grid, 0),
+                make_multi_index(k_batch_id, n_block_data_idx_on_grid, 0),
                 b_element_op,
                 b_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
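Both block-wise copies now start at k_batch_id in the leading (K0-split) dimension, so each K batch streams only its own slice of A and B. A sketch of the slice bounds this implies (the [KBatch, K0PerBatch, K1] layout is an assumption about the descriptor, and the helper is hypothetical):

    // Which K range does batch kb cover when the copy origin is
    // (kb, m_or_n, 0)? Half-open interval [begin, end).
    std::pair<ck::index_t, ck::index_t>
    k_slice(ck::index_t kb, ck::index_t k0_per_batch, ck::index_t k1)
    {
        const ck::index_t begin = kb * k0_per_batch * k1;
        return {begin, begin + k0_per_batch * k1};
    }

Between them the KBatch slices cover K exactly once, so the atomic adds reconstruct the full reduction.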
@@ -640,6 +644,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
                 n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};
 
+        // add multiple D at the first atomic-op position
+        // if(k_batch_id == 0)
+
         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_desc_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
@@ -665,12 +672,6 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             },
             Number<NumDTensor>{}));
 
-        // only do bias at the 1st atomic add position
-        if constexpr(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd)
-        {
-        }
-
         // blockwise copy C/D/E between LDS and global
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
             ThisThreadBlock,
@@ -679,8 +680,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             decltype(c_ds_desc_refs),
             decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
             CDEElementwiseOperation,
             Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
                                                                         // support arbitrary type
             Sequence<1,
                      CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                      1,
@@ -698,9 +700,35 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             {c_ds_desc_refs,
              idx_c_ds_block_begin,
              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-             make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
+             make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
              cde_element_op};
 
+        // block-wise copy E between LDS and global
+        auto e_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
+            ThisThreadBlock,                                 // index_t BlockSize
+            ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation
+            EGlobalMemoryDataOperation,                      // DstInMemOp
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths
+            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder
+            EDataType,            // typename SrcData
+            EDataType,            // typename DstData
+            decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+            decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
+            Sequence<0, 1, 2, 3>, // typename DimAccessOrder
+            3,                    // index_t VectorDim
+            CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector
+            true,                 // bool ThreadTransferSrcResetCoordinateAfterRun
+            false>                // bool ThreadTransferDstResetCoordinateAfterRun
+            {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+             make_multi_index(0, 0, 0, 0),
+             e_grid_desc_mblock_mperblock_nblock_nperblock,
+             make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
+             ck::tensor_operation::element_wise::PassThrough{}};
+
         // space filling curve for threadwise C in VGPR before shuffle
         constexpr auto sfc_c_vgpr =
             SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -741,12 +769,23 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
                 // make sure it's safe to read from LDS
                 block_sync_lds();
 
-                // each block copies its data from LDS to global
-                cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
-                                                  c_ds_buf_refs,
-                                                  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                                                  tie(e_grid_buf));
+                if(k_batch_id == 0)
+                {
+                    // each block copies its data from LDS to global
+                    cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
+                                                      c_ds_buf_refs,
+                                                      tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                                                      tie(e_grid_buf));
+                }
+                else
+                {
+                    e_block_copy_lds_to_global.Run(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                                                   c_shuffle_block_buf,
+                                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                   e_grid_buf);
+                }
 
                 if constexpr(access_id < num_access - 1)
                 {
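This branch is what keeps the bias from being applied once per K batch: only blocks with k_batch_id == 0 take the CDE path (partial product combined with D through cde_element_op), while every other batch atomically adds its bare partial product through the PassThrough path. Numerically, with KBatch = 4 and partial products p0..p3:

    E = (p0 + d) + p1 + p2 + p3 = (p0 + p1 + p2 + p3) + d

whereas running the CDE path in every batch would wrongly produce (p0 + p1 + p2 + p3) + 4d.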