Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
523a4045
Commit
523a4045
authored
Jan 21, 2022
by
Chao Liu
Browse files
clang-format
parent
2d31e921
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
72 deletions
+73
-72
device_operation/include/device_gemm_shuffle_xdl.hpp
device_operation/include/device_gemm_shuffle_xdl.hpp
+73
-72
No files found.
device_operation/include/device_gemm_shuffle_xdl.hpp
View file @
523a4045
...
@@ -17,43 +17,44 @@ namespace ck {
...
@@ -17,43 +17,44 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
typename
ADataType
,
template
<
typename
BDataType
,
typename
ADataType
,
typename
CDataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
AccDataType
,
typename
BLayout
,
typename
ALayout
,
typename
CLayout
,
typename
BLayout
,
typename
AElementwiseOperation
,
typename
CLayout
,
typename
BElementwiseOperation
,
typename
AElementwiseOperation
,
typename
CElementwiseOperation
,
typename
BElementwiseOperation
,
ck
::
index_t
BlockSize
,
typename
CElementwiseOperation
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
BlockSize
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
K1
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
NXdlPerWave
,
ck
::
index_t
MXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
bool
ABlockLdsAddExtraM
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferThreadClusterArrangeOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
index_t
CShuffleMXdlPerWavePerShuffle
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleMXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
typename
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceGemmShuffleXdl
struct
DeviceGemmShuffleXdl
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
{
...
@@ -175,7 +176,6 @@ struct DeviceGemmShuffleXdl
...
@@ -175,7 +176,6 @@ struct DeviceGemmShuffleXdl
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -215,8 +215,9 @@ struct DeviceGemmShuffleXdl
...
@@ -215,8 +215,9 @@ struct DeviceGemmShuffleXdl
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
=
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
GridwiseGemm
::
c_grid_desc_m_n_
);
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
...
@@ -295,21 +296,22 @@ struct DeviceGemmShuffleXdl
...
@@ -295,21 +296,22 @@ struct DeviceGemmShuffleXdl
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
nrepeat
,
kernel
,
dim3
(
grid_size
),
nrepeat
,
dim3
(
BlockSize
),
dim3
(
grid_size
),
0
,
dim3
(
BlockSize
),
arg
.
p_a_grid_
,
0
,
arg
.
p_b_grid_
,
arg
.
p_a_grid_
,
arg
.
p_c_grid_
,
arg
.
p_b_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
p_c_grid_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_element_op_
,
arg
.
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
,
arg
.
b_element_op_
,
arg
.
a_element_op_
,
arg
.
c_element_op_
,
arg
.
b_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
}
else
else
{
{
...
@@ -328,21 +330,22 @@ struct DeviceGemmShuffleXdl
...
@@ -328,21 +330,22 @@ struct DeviceGemmShuffleXdl
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
nrepeat
,
kernel
,
dim3
(
grid_size
),
nrepeat
,
dim3
(
BlockSize
),
dim3
(
grid_size
),
0
,
dim3
(
BlockSize
),
arg
.
p_a_grid_
,
0
,
arg
.
p_b_grid_
,
arg
.
p_a_grid_
,
arg
.
p_c_grid_
,
arg
.
p_b_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
p_c_grid_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_element_op_
,
arg
.
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
,
arg
.
b_element_op_
,
arg
.
a_element_op_
,
arg
.
c_element_op_
,
arg
.
b_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -355,7 +358,6 @@ struct DeviceGemmShuffleXdl
...
@@ -355,7 +358,6 @@ struct DeviceGemmShuffleXdl
}
}
};
};
static
constexpr
bool
IsValidCompilationParameter
()
static
constexpr
bool
IsValidCompilationParameter
()
{
{
// TODO: properly implement this check
// TODO: properly implement this check
...
@@ -438,7 +440,6 @@ struct DeviceGemmShuffleXdl
...
@@ -438,7 +440,6 @@ struct DeviceGemmShuffleXdl
c_element_op
);
c_element_op
);
}
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment