Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
300ac4e4
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "296b01e1a15a4feecac267050543d62e921d5875"
Commit
300ac4e4
authored
Jun 02, 2022
by
rocking
Browse files
Implement gemm bias add reduction
parent
09a2b547
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
703 additions
and
77 deletions
+703
-77
example/21_gemm_layernorm/CMakeLists.txt
example/21_gemm_layernorm/CMakeLists.txt
+1
-0
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
..._gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
+415
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
...n/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
+73
-31
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
+46
-0
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+15
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
...pu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+153
-46
No files found.
example/21_gemm_layernorm/CMakeLists.txt
View file @
300ac4e4
add_example_executable
(
example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp
)
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
0 → 100644
View file @
300ac4e4
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
View file @
300ac4e4
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
#include "gridwise_gemm_
bias_add_
reduce_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -23,6 +23,8 @@ template <typename ALayout,
...
@@ -23,6 +23,8 @@ template <typename ALayout,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
C0DataType
,
typename
C1DataType
,
typename
GemmAccDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
CShuffleDataType
,
typename
ReduceAccDataType
,
typename
ReduceAccDataType
,
...
@@ -68,14 +70,15 @@ template <typename ALayout,
...
@@ -68,14 +70,15 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
DPtrsGlobal
,
struct
DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation
,
:
public
DeviceGemmBiasAddReduce
<
DPtrsGlobal
,
BElementwiseOperation
,
AElementwiseOperation
,
CElementwiseOperation
,
BElementwiseOperation
,
DxsInElementwiseOperation
,
CElementwiseOperation
,
DxsAccElementwiseOperation
>
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemm
BiasAdd
Reduce_Xdl_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -374,14 +377,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -374,14 +377,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
C0GridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
0
));
using
C1GridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
using
GridwiseGemm
=
GridwiseGemm
BiasAdd
Reduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
C0DataType
,
C1DataType
,
ReduceAccDataType
,
ReduceAccDataType
,
DPtrsGlobal
,
DPtrsGlobal
,
AElementwiseOperation
,
AElementwiseOperation
,
...
@@ -395,6 +402,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -395,6 +402,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
C0GridDesc_M_N
,
C1GridDesc_M_N
,
DGridDesc_M
,
DGridDesc_M
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
...
@@ -438,6 +447,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -438,6 +447,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
Argument
(
const
ADataType
*
p_a_grid
,
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
const
C0DataType
*
p_c0_grid
,
const
C1DataType
*
p_c1_grid
,
DPtrsGlobal
p_ds_grid
,
DPtrsGlobal
p_ds_grid
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
...
@@ -445,6 +456,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -445,6 +456,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
...
@@ -453,12 +465,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -453,12 +465,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
p_c0_grid_
{
p_c0_grid
},
p_c1_grid_
{
p_c1_grid
},
p_ds_grid_
{
p_ds_grid
},
p_ds_grid_
{
p_ds_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
c0_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
0
)},
c1_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC1
)},
d_grid_desc_m_
{
DeviceOp
::
MakeDGridDescriptor_M
(
MRaw
)},
d_grid_desc_m_
{
DeviceOp
::
MakeDGridDescriptor_M
(
MRaw
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d_grid_desc_mblock_mperblock_
{},
d_grid_desc_mblock_mperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
...
@@ -476,6 +494,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -476,6 +494,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
c0_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c0_grid_desc_m_n_
);
c1_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c1_grid_desc_m_n_
);
d_grid_desc_mblock_mperblock_
=
d_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
MakeDGridDescriptor_MBlock_MPerBlock
(
d_grid_desc_m_
);
GridwiseGemm
::
MakeDGridDescriptor_MBlock_MPerBlock
(
d_grid_desc_m_
);
}
}
...
@@ -485,13 +511,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -485,13 +511,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
const
C0DataType
*
p_c0_grid_
;
const
C1DataType
*
p_c1_grid_
;
DPtrsGlobal
p_ds_grid_
;
DPtrsGlobal
p_ds_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C1GridDesc_M_N
c1_grid_desc_m_n_
;
DGridDesc_M
d_grid_desc_m_
;
DGridDesc_M
d_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
...
@@ -508,26 +542,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -508,26 +542,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
...
@@ -545,10 +559,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -545,10 +559,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
float
elapsed_time
=
0.0
f
;
float
elapsed_time
=
0.0
f
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
kernel_gemm_reduce_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_gemm_
bias_add_
reduce_xdl_cshuffle_v1
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
C0DataType
,
C1DataType
,
DPtrsGlobal
,
DPtrsGlobal
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -558,6 +574,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -558,6 +574,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
true
>
;
true
>
;
...
@@ -571,6 +589,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -571,6 +589,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -580,15 +600,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -580,15 +600,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c0_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c1_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d_grid_desc_mblock_mperblock_
,
arg
.
d_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_reduce_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_gemm_
bias_add_
reduce_xdl_cshuffle_v1
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
C0DataType
,
C1DataType
,
DPtrsGlobal
,
DPtrsGlobal
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -598,6 +622,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -598,6 +622,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
false
>
;
false
>
;
...
@@ -611,6 +637,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -611,6 +637,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -620,6 +648,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -620,6 +648,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c0_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c1_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d_grid_desc_mblock_mperblock_
,
arg
.
d_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
}
}
...
@@ -658,6 +688,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -658,6 +688,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
CDataType
*
p_c
,
const
C0DataType
*
p_c0
,
const
C1DataType
*
p_c1
,
DPtrsGlobal
p_dxs
,
DPtrsGlobal
p_dxs
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
...
@@ -665,6 +697,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -665,6 +697,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
...
@@ -674,6 +707,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -674,6 +707,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
p_c
,
p_c
,
p_c0
,
p_c1
,
p_dxs
,
p_dxs
,
MRaw
,
MRaw
,
NRaw
,
NRaw
,
...
@@ -681,6 +716,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -681,6 +716,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
StrideC1
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
...
@@ -694,6 +730,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -694,6 +730,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
void
*
p_c
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
DPtrsGlobal
p_dxs
,
DPtrsGlobal
p_dxs
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
...
@@ -701,6 +739,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -701,6 +739,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
...
@@ -711,6 +750,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -711,6 +750,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
const
C0DataType
*>
(
p_c0
),
static_cast
<
const
C1DataType
*>
(
p_c1
),
p_dxs
,
p_dxs
,
MRaw
,
MRaw
,
NRaw
,
NRaw
,
...
@@ -718,6 +759,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -718,6 +759,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
StrideC1
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
View file @
300ac4e4
...
@@ -48,6 +48,52 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
...
@@ -48,6 +48,52 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>>
;
DxsAccElementwiseOperation
>>
;
template
<
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
struct
DeviceGemmBiasAddReduce
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
DPtrsGlobal
p_dxs
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
using
DeviceGemmBiasAddReducePtr
=
std
::
unique_ptr
<
DeviceGemmBiasAddReduce
<
DPtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
300ac4e4
...
@@ -143,6 +143,21 @@ struct AddHardswishAdd
...
@@ -143,6 +143,21 @@ struct AddHardswishAdd
}
}
};
};
struct
Relu
{
__host__
__device__
void
operator
()(
float
&
y
,
const
float
&
x
)
const
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
void
operator
()(
double
&
y
,
const
double
&
x
)
const
{
y
=
x
>
0
?
x
:
0
;
}
};
struct
Normalize
struct
Normalize
{
{
Normalize
(
float
epsilon
=
1e-4
)
:
epsilon_
(
epsilon
)
{}
Normalize
(
float
epsilon
=
1e-4
)
:
epsilon_
(
epsilon
)
{}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
View file @
300ac4e4
...
@@ -16,6 +16,8 @@ namespace ck {
...
@@ -16,6 +16,8 @@ namespace ck {
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
FloatC0
,
typename
FloatC1
,
typename
DPtrsGlobal
,
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
...
@@ -25,6 +27,8 @@ template <typename GridwiseGemm,
...
@@ -25,6 +27,8 @@ template <typename GridwiseGemm,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DGridDescriptor_MBlock_MPerBlock
,
typename
DGridDescriptor_MBlock_MPerBlock
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
>
...
@@ -32,10 +36,12 @@ __global__ void
...
@@ -32,10 +36,12 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_reduce_xdl_cshuffle_v1
(
kernel_gemm_
bias_add_
reduce_xdl_cshuffle_v1
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
DPtrsGlobal
p_ds_grid
,
DPtrsGlobal
p_ds_grid
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
...
@@ -46,6 +52,10 @@ __global__ void
...
@@ -46,6 +52,10 @@ __global__ void
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock
,
const
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
...
@@ -55,6 +65,8 @@ __global__ void
...
@@ -55,6 +65,8 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_c0_grid
,
p_c1_grid
,
p_ds_grid
,
p_ds_grid
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
...
@@ -65,12 +77,16 @@ __global__ void
...
@@ -65,12 +77,16 @@ __global__ void
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
d_grid_desc_mblock_mperblock
,
d_grid_desc_mblock_mperblock
,
block_2_ctile_map
);
block_2_ctile_map
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_c_grid
;
ignore
=
p_c0_grid
;
ignore
=
p_c1_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_ds_grid
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
...
@@ -80,6 +96,8 @@ __global__ void
...
@@ -80,6 +96,8 @@ __global__ void
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c0_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c1_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
d_grid_desc_mblock_mperblock
;
ignore
=
d_grid_desc_mblock_mperblock
;
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...
@@ -89,6 +107,8 @@ template <typename FloatAB,
...
@@ -89,6 +107,8 @@ template <typename FloatAB,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatCShuffle
,
typename
FloatC
,
typename
FloatC
,
typename
FloatC0
,
typename
FloatC1
,
typename
FloatReduceAcc
,
typename
FloatReduceAcc
,
typename
DPtrsGlobal
,
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -102,6 +122,8 @@ template <typename FloatAB,
...
@@ -102,6 +122,8 @@ template <typename FloatAB,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C1GridDesc_M_N
,
typename
DGridDesc_M
,
typename
DGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -138,7 +160,7 @@ template <typename FloatAB,
...
@@ -138,7 +160,7 @@ template <typename FloatAB,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
>
LoopScheduler
LoopSched
>
struct
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
struct
GridwiseGemm
BiasAdd
Reduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -268,8 +290,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -268,8 +290,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
}
template
<
typename
CGridDesc_M_N_
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
_
&
c_grid_desc_m_n
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
...
@@ -313,6 +336,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -313,6 +336,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
C0GridDesc_M_N
{}))
>
;
using
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
C1GridDesc_M_N
{}))
>
;
using
DGridDescriptor_MBlock_MPerBlock
=
using
DGridDescriptor_MBlock_MPerBlock
=
remove_cvref_t
<
decltype
(
MakeDGridDescriptor_MBlock_MPerBlock
(
DGridDesc_M
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDGridDescriptor_MBlock_MPerBlock
(
DGridDesc_M
{}))
>
;
...
@@ -323,6 +352,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -323,6 +352,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
DPtrsGlobal
p_ds_grid
,
DPtrsGlobal
p_ds_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
...
@@ -334,6 +365,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -334,6 +365,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDescriptor_MBlock_MPerBlock
&
d_grid_desc_mblock_mperblock
,
const
DGridDescriptor_MBlock_MPerBlock
&
d_grid_desc_mblock_mperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
...
@@ -343,6 +378,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -343,6 +378,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c0_grid
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
c1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c1_grid
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
@@ -610,32 +649,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -610,32 +649,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
n_thread_data_on_block_idx
[
I2
]),
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
// shuffle: blockwise copy C from LDS to global
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
ThisThreadBlock
,
// ThreadGroup
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
FloatC
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
...
@@ -759,14 +772,80 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -759,14 +772,80 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
},
},
Number
<
p_ds_grid
.
Size
()
>
{});
Number
<
p_ds_grid
.
Size
()
>
{});
// c0 and c1
constexpr
auto
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
mreduce_per_thread
>
{},
I1
,
Number
<
nreduce_per_thread
>
{}));
constexpr
auto
c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock
=
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock
;
auto
c01_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
c0_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatC0
,
FloatReduceAcc
,
decltype
(
c0_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
I1
,
mreduce_per_thread
,
I1
,
nreduce_per_thread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
1
,
true
>
(
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
I0
,
m_block_data_idx_on_grid
+
c_reduce_thread_data_idx_begin
[
I0
],
I0
,
n_block_data_idx_on_grid
+
c_reduce_thread_data_idx_begin
[
I1
]));
auto
c1_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatC1
,
FloatReduceAcc
,
decltype
(
c1_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
I1
,
mreduce_per_thread
,
I1
,
nreduce_per_thread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
1
,
true
>
(
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
I0
,
m_block_data_idx_on_grid
+
c_reduce_thread_data_idx_begin
[
I0
],
I0
,
n_block_data_idx_on_grid
+
c_reduce_thread_data_idx_begin
[
I1
]));
constexpr
auto
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
mreduce_per_thread
>
{},
I1
,
Number
<
nreduce_per_thread
>
{}));
auto
c_reduce_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatReduceAcc
,
FloatC
,
decltype
(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
mreduce_per_thread
,
I1
,
nreduce_per_thread
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
3
,
// DstVectorDim
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
I0
,
m_block_data_idx_on_grid
+
c_reduce_thread_data_idx_begin
[
I0
],
I0
,
n_block_data_idx_on_grid
+
c_reduce_thread_data_idx_begin
[
I1
]),
tensor_operation
::
element_wise
::
PassThrough
{}};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_c_global
.
GetNumOfAccess
(),
"wrong!"
);
static_assert
(
num_access
==
sfc_c_global
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
...
@@ -774,17 +853,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -774,17 +853,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
c_shuffle_block_buf
);
// make sure it's safe to
read from
LDS
// make sure it's safe to
write to
LDS
block_sync_lds
();
block_sync_lds
();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.
Run
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
// TODO - extract following into reduction_blockwise
{
{
c_reduce_thread_copy_lds_to_vgpr
.
Run
(
c_reduce_block_desc_mperblock_nperblock
,
c_reduce_thread_copy_lds_to_vgpr
.
Run
(
c_reduce_block_desc_mperblock_nperblock
,
c_shuffle_block_buf
,
c_shuffle_block_buf
,
...
@@ -792,6 +862,37 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -792,6 +862,37 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
c_reduce_thread_buf
);
c_reduce_thread_buf
);
c0_thread_copy_global_to_vgpr
.
Run
(
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
c0_grid_buf
,
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c01_thread_buf
);
static_for
<
0
,
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSize
(),
1
>
{}(
[
&
](
auto
i
)
{
FloatReduceAcc
out
;
c_element_op
(
out
,
c_reduce_thread_buf
(
i
)
+
c01_thread_buf
(
i
));
c_reduce_thread_buf
(
i
)
=
out
;
});
c1_thread_copy_global_to_vgpr
.
Run
(
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
c1_grid_buf
,
c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c01_thread_buf
);
static_for
<
0
,
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSize
(),
1
>
{}(
[
&
](
auto
i
)
{
c_reduce_thread_buf
(
i
)
+=
c01_thread_buf
(
i
);
});
c_reduce_thread_copy_vgpr_to_global
.
Run
(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_reduce_thread_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
static_for
<
0
,
p_ds_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
static_for
<
0
,
p_ds_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_d_grid
=
p_ds_grid
[
In
];
auto
&
p_d_grid
=
p_ds_grid
[
In
];
...
@@ -858,13 +959,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -858,13 +959,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
// move on C
// move on C
c_
shuffle_block
_copy_
lds
_to_global
.
MoveDstSliceWindow
(
c_
reduce_thread
_copy_
vgpr
_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
// move on C0
c0_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
// move on C1
c1_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
}
});
});
}
// Reduction
// Reduction
}
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment