Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
66206c23
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "0537226c75d542225553ca2a7ae3b88fb75b0d7a"
Commit
66206c23
authored
Jan 21, 2022
by
Chao Liu
Browse files
rename
parent
ad8c418d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
51 deletions
+53
-51
device_operation/include/device_gemm_xdl_c_shuffle.hpp
device_operation/include/device_gemm_xdl_c_shuffle.hpp
+13
-11
example/1_gemm_xdl/gemm_xdl.cpp
example/1_gemm_xdl/gemm_xdl.cpp
+40
-40
No files found.
device_operation/include/device_gemm_shuffle
_xdl
.hpp
→
device_operation/include/device_gemm_
xdl_c_
shuffle.hpp
View file @
66206c23
#ifndef DEVICE_GEMM_SHUFFLE_
XDL_
HPP
#define DEVICE_GEMM_SHUFFLE_
XDL_
HPP
#ifndef DEVICE_GEMM_
XDL_C_
SHUFFLE_HPP
#define DEVICE_GEMM_
XDL_C_
SHUFFLE_HPP
#include <iostream>
#include <sstream>
...
...
@@ -55,7 +55,7 @@ template <
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceGemmShuffle
Xdl
struct
DeviceGemm
Xdl_C_
Shuffle
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -207,9 +207,11 @@ struct DeviceGemmShuffleXdl
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
a_grid_desc_k0_m_k1_
=
DeviceGemmShuffleXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmShuffleXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmShuffleXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
a_grid_desc_k0_m_k1_
=
DeviceGemmXdl_C_Shuffle
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmXdl_C_Shuffle
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmXdl_C_Shuffle
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
...
...
@@ -244,7 +246,7 @@ struct DeviceGemmShuffleXdl
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceGemmShuffle
Xdl
::
Argument
;
using
Argument
=
DeviceGemm
Xdl_C_
Shuffle
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
...
...
@@ -285,8 +287,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmShuffle
Xdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmShuffle
Xdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemm
Xdl_C_
Shuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemm
Xdl_C_
Shuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
>
,
...
...
@@ -319,8 +321,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmShuffle
Xdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmShuffle
Xdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemm
Xdl_C_
Shuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemm
Xdl_C_
Shuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
>
,
...
...
example/1_gemm_xdl/gemm_xdl.cpp
View file @
66206c23
...
...
@@ -12,7 +12,7 @@
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
#include "device_gemm_shuffle
_xdl
.hpp"
#include "device_gemm_
xdl_c_
shuffle.hpp"
#include "element_wise_operation.hpp"
template
<
ck
::
index_t
...
Is
>
...
...
@@ -32,44 +32,44 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// clang-format off
using
DeviceGemm
Shuffle
Instance
=
ck
::
tensor_operation
::
device
::
DeviceGemmShuffleXdl
<
ADataType
,
// ADataType
BDataType
,
// BDataType
CDataType
,
// CDataType
AccDataType
,
// AccDataType
ALayout
,
// ALayout
BLayout
,
// BLayout
CLayout
,
// CLayout
AElementOp
,
// AElementwiseOperation
BElementOp
,
// BElementwiseOperation
CElementOp
,
// CElementwiseOperation
256
,
// BlockSize
256
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle
<
ADataType
,
// ADataType
BDataType
,
// BDataType
CDataType
,
// CDataType
AccDataType
,
// AccDataType
ALayout
,
// ALayout
BLayout
,
// BLayout
CLayout
,
// CLayout
AElementOp
,
// AElementwiseOperation
BElementOp
,
// BElementwiseOperation
CElementOp
,
// CElementwiseOperation
256
,
// BlockSize
256
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template
<
typename
AType
,
...
...
@@ -192,7 +192,7 @@ int main(int argc, char* argv[])
c_m_n_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// do GEMM
auto
gemm
=
DeviceGemm
Shuffle
Instance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment