Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c44818e7
Commit
c44818e7
authored
Jun 13, 2022
by
rocking
Browse files
Rename DxsAccElementwiseOperation to DxsReduceAccElementwiseOperation
parent
46eca0a1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
103 additions
and
99 deletions
+103
-99
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
...on/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+33
-31
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
...n/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
+29
-28
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
+8
-8
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
..._operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
+25
-24
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
...pu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+4
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+4
-4
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
c44818e7
...
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
...
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -44,7 +44,7 @@ __global__ void
...
@@ -44,7 +44,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsAccElementwiseOperation
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -126,7 +126,7 @@ template <typename ALayout,
...
@@ -126,7 +126,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
...
@@ -162,12 +162,13 @@ template <typename ALayout,
...
@@ -162,12 +162,13 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceBatchedGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
DPtrsGlobal
,
struct
DeviceBatchedGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
DPtrsGlobal
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>
Dxs
Reduce
AccElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceBatchedGemmReduce_Xdl_CShuffle
;
using
DeviceOp
=
DeviceBatchedGemmReduce_Xdl_CShuffle
;
...
@@ -527,7 +528,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -527,7 +528,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
CElementwiseOperation
,
CElementwiseOperation
,
DxsReduceOperation
,
DxsReduceOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
DGlobalMemoryDataOperation
,
DGlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
...
@@ -587,7 +588,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -587,7 +588,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
index_t
BatchCount
)
index_t
BatchCount
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
...
@@ -645,7 +646,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -645,7 +646,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsAccElementwiseOperation
dxs_out_element_op_
;
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op_
;
};
};
// Invoker
// Invoker
...
@@ -703,7 +704,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -703,7 +704,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -746,7 +747,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -746,7 +747,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -832,7 +833,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -832,7 +833,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
index_t
BatchCount
)
index_t
BatchCount
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
...
@@ -856,7 +857,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -856,7 +857,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
void
*
p_c
,
void
*
p_c
,
DPtrsGlobal
p_dxs
,
DPtrsGlobal
p_dxs
,
...
@@ -870,7 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -870,7 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
index_t
BatchCount
)
override
index_t
BatchCount
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
View file @
c44818e7
...
@@ -35,7 +35,7 @@ template <typename ALayout,
...
@@ -35,7 +35,7 @@ template <typename ALayout,
typename
C1ElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
...
@@ -78,7 +78,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -78,7 +78,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>
Dxs
Reduce
AccElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemmBiasAddReduce_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmBiasAddReduce_Xdl_CShuffle
;
...
@@ -399,7 +399,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -399,7 +399,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
C1ElementwiseOperation
,
C1ElementwiseOperation
,
DxsReduceOperation
,
DxsReduceOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
DGlobalMemoryDataOperation
,
DGlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
...
@@ -465,7 +465,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -465,7 +465,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
)
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
...
@@ -538,7 +538,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -538,7 +538,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
C1ElementwiseOperation
c1_element_op_
;
C1ElementwiseOperation
c1_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsAccElementwiseOperation
dxs_out_element_op_
;
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op_
;
};
};
// Invoker
// Invoker
...
@@ -577,7 +577,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -577,7 +577,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -627,7 +627,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -627,7 +627,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -713,7 +713,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -713,7 +713,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
)
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
...
@@ -739,7 +739,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -739,7 +739,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
void
*
p_c
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c0
,
...
@@ -757,7 +758,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
...
@@ -757,7 +758,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
index_t
/* KBatch */
=
1
)
override
index_t
/* KBatch */
=
1
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
View file @
c44818e7
...
@@ -11,7 +11,7 @@ template <typename DPtrsGlobal,
...
@@ -11,7 +11,7 @@ template <typename DPtrsGlobal,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
typename
Dxs
Reduce
AccElementwiseOperation
>
struct
DeviceGemmReduce
:
public
BaseOperator
struct
DeviceGemmReduce
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
...
@@ -29,7 +29,7 @@ struct DeviceGemmReduce : public BaseOperator
...
@@ -29,7 +29,7 @@ struct DeviceGemmReduce : public BaseOperator
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
...
@@ -40,13 +40,13 @@ template <typename DPtrsGlobal,
...
@@ -40,13 +40,13 @@ template <typename DPtrsGlobal,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
typename
Dxs
Reduce
AccElementwiseOperation
>
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
DPtrsGlobal
,
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
DPtrsGlobal
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>>
;
Dxs
Reduce
AccElementwiseOperation
>>
;
template
<
typename
DPtrsGlobal
,
template
<
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -54,7 +54,7 @@ template <typename DPtrsGlobal,
...
@@ -54,7 +54,7 @@ template <typename DPtrsGlobal,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
typename
Dxs
Reduce
AccElementwiseOperation
>
struct
DeviceGemmBiasAddReduce
:
public
BaseOperator
struct
DeviceGemmBiasAddReduce
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
...
@@ -76,7 +76,7 @@ struct DeviceGemmBiasAddReduce : public BaseOperator
...
@@ -76,7 +76,7 @@ struct DeviceGemmBiasAddReduce : public BaseOperator
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
...
@@ -88,7 +88,7 @@ template <typename DPtrsGlobal,
...
@@ -88,7 +88,7 @@ template <typename DPtrsGlobal,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
typename
Dxs
Reduce
AccElementwiseOperation
>
using
DeviceGemmBiasAddReducePtr
=
using
DeviceGemmBiasAddReducePtr
=
std
::
unique_ptr
<
DeviceGemmBiasAddReduce
<
DPtrsGlobal
,
std
::
unique_ptr
<
DeviceGemmBiasAddReduce
<
DPtrsGlobal
,
AElementwiseOperation
,
AElementwiseOperation
,
...
@@ -96,7 +96,7 @@ using DeviceGemmBiasAddReducePtr =
...
@@ -96,7 +96,7 @@ using DeviceGemmBiasAddReducePtr =
CElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>>
;
Dxs
Reduce
AccElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
c44818e7
...
@@ -32,7 +32,7 @@ template <typename ALayout,
...
@@ -32,7 +32,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
...
@@ -73,7 +73,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -73,7 +73,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>
Dxs
Reduce
AccElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
...
@@ -389,7 +389,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -389,7 +389,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
CElementwiseOperation
,
CElementwiseOperation
,
DxsReduceOperation
,
DxsReduceOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
DGlobalMemoryDataOperation
,
DGlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
...
@@ -449,7 +449,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -449,7 +449,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
)
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
...
@@ -498,7 +498,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -498,7 +498,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsAccElementwiseOperation
dxs_out_element_op_
;
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op_
;
};
};
// Invoker
// Invoker
...
@@ -554,7 +554,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -554,7 +554,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -594,7 +594,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -594,7 +594,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -669,7 +669,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -669,7 +669,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
)
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
...
@@ -691,7 +691,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -691,7 +691,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
void
*
p_c
,
void
*
p_c
,
DPtrsGlobal
p_dxs
,
DPtrsGlobal
p_dxs
,
...
@@ -705,7 +706,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -705,7 +706,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
index_t
/* KBatch */
=
1
)
override
index_t
/* KBatch */
=
1
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
View file @
c44818e7
...
@@ -24,7 +24,7 @@ template <typename GridwiseGemm,
...
@@ -24,7 +24,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -49,7 +49,7 @@ __global__ void
...
@@ -49,7 +49,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
C1ElementwiseOperation
c1_element_op
,
const
C1ElementwiseOperation
c1_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsAccElementwiseOperation
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -121,7 +121,7 @@ template <typename FloatAB,
...
@@ -121,7 +121,7 @@ template <typename FloatAB,
typename
C1ElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
...
@@ -366,7 +366,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -366,7 +366,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
C1ElementwiseOperation
&
c1_element_op
,
const
C1ElementwiseOperation
&
c1_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsAccElementwiseOperation
&
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
&
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
c44818e7
...
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
...
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -41,7 +41,7 @@ __global__ void
...
@@ -41,7 +41,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsAccElementwiseOperation
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -96,7 +96,7 @@ template <typename FloatAB,
...
@@ -96,7 +96,7 @@ template <typename FloatAB,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
...
@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsAccElementwiseOperation
&
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
&
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment