"vscode:/vscode.git/clone" did not exist on "90c07a3dfd99f14bfbc5b43f59b96ce48fc4d0ec"
Commit a76eac28 authored by rocking

[What] Remove the tuple type from the base class

[Why] The external API depends on the base class. If the base class is tied to a concrete tuple type, a separate base class would be needed for every type.
parent f9d22b02
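
The pattern applied throughout this commit, shown as a minimal, hedged C++ sketch (the stand-in names DeviceGemmReduceBase and DeviceGemmReduceImpl are illustrative, not the actual composable_kernel classes): the base interface takes the reduction-pointer tuple as an opaque void*, so it no longer needs DPtrsGlobal as a template parameter, and only the fully templated derived operation casts the pointer back to the concrete tuple type.

#include <memory>

struct BaseArgument
{
    virtual ~BaseArgument() = default;
};

// Base interface: no longer templated on the tuple of reduction pointers,
// so a single base class serves every concrete instantiation.
struct DeviceGemmReduceBase
{
    virtual ~DeviceGemmReduceBase() = default;

    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              void* p_dxs) = 0; // was: DPtrsGlobal p_dxs
};

// Derived operation: still knows the concrete tuple type and recovers it by
// casting the opaque pointer back, as the dxs_tuple lines in the diff below do.
template <typename DPtrsGlobal>
struct DeviceGemmReduceImpl : DeviceGemmReduceBase
{
    struct Argument : BaseArgument
    {
        Argument(const void*, const void*, void*, DPtrsGlobal) {}
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      void* p_dxs) override
    {
        // The caller owns the tuple object and passes only its address.
        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
        return std::make_unique<Argument>(p_a, p_b, p_c, dxs_tuple);
    }
};

The trade-off is that the void* parameter gives up compile-time type checking at the interface boundary; the caller must pass the address of an object whose type matches the derived class's DPtrsGlobal.
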
@@ -163,8 +163,7 @@ template <typename ALayout,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceBatchedGemmReduce_Xdl_CShuffle
-    : public DeviceGemmReduce<DPtrsGlobal,
-                              AElementwiseOperation,
+    : public DeviceGemmReduce<AElementwiseOperation,
                               BElementwiseOperation,
                               CElementwiseOperation,
                               DxsInElementwiseOperation,
@@ -861,7 +860,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
                         void* p_c,
-                        DPtrsGlobal p_dxs,
+                        void* p_dxs,
                         index_t MRaw,
                         index_t NRaw,
                         index_t KRaw,
@@ -875,10 +874,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                         DxsReduceAccElementwiseOperation dxs_out_element_op,
                         index_t BatchCount) override
     {
+        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          p_dxs,
+                                          dxs_tuple,
                                           MRaw,
                                           NRaw,
                                           KRaw,
...
@@ -71,8 +71,7 @@ template <typename ALayout,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGemmBiasAddReduce_Xdl_CShuffle
-    : public DeviceGemmBiasAddReduce<DPtrsGlobal,
-                                     AElementwiseOperation,
+    : public DeviceGemmBiasAddReduce<AElementwiseOperation,
                                      BElementwiseOperation,
                                      CElementwiseOperation,
                                      C1ElementwiseOperation,
@@ -744,7 +743,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                         void* p_c,
                         const void* p_c0,
                         const void* p_c1,
-                        DPtrsGlobal p_dxs,
+                        void* p_dxs,
                         index_t MRaw,
                         index_t NRaw,
                         index_t KRaw,
@@ -760,12 +759,13 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                         DxsReduceAccElementwiseOperation dxs_out_element_op,
                         index_t /* KBatch */ = 1) override
     {
+        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
                                           static_cast<const C0DataType*>(p_c0),
                                           static_cast<const C1DataType*>(p_c1),
-                                          p_dxs,
+                                          dxs_tuple,
                                           MRaw,
                                           NRaw,
                                           KRaw,
...
@@ -6,8 +6,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename DPtrsGlobal,
-          typename AElementwiseOperation,
+template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename DxsInElementwiseOperation,
@@ -18,7 +17,7 @@ struct DeviceGemmReduce : public BaseOperator
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
                         void* p_c,
-                        DPtrsGlobal p_dxs,
+                        void* p_dxs,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
@@ -35,21 +34,18 @@ struct DeviceGemmReduce : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename DPtrsGlobal,
-          typename AElementwiseOperation,
+template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename DxsInElementwiseOperation,
           typename DxsReduceAccElementwiseOperation>
-using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
-                                                             AElementwiseOperation,
+using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
                                                              BElementwiseOperation,
                                                              CElementwiseOperation,
                                                              DxsInElementwiseOperation,
                                                              DxsReduceAccElementwiseOperation>>;

-template <typename DPtrsGlobal,
-          typename AElementwiseOperation,
+template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename C1ElementwiseOperation,
@@ -63,7 +59,7 @@ struct DeviceGemmBiasAddReduce : public BaseOperator
                         void* p_c,
                         const void* p_c0,
                         const void* p_c1,
-                        DPtrsGlobal p_dxs,
+                        void* p_dxs,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
@@ -82,16 +78,14 @@ struct DeviceGemmBiasAddReduce : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename DPtrsGlobal,
-          typename AElementwiseOperation,
+template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename C1ElementwiseOperation,
           typename DxsInElementwiseOperation,
           typename DxsReduceAccElementwiseOperation>
 using DeviceGemmBiasAddReducePtr =
-    std::unique_ptr<DeviceGemmBiasAddReduce<DPtrsGlobal,
-                                            AElementwiseOperation,
+    std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
                                             BElementwiseOperation,
                                             CElementwiseOperation,
                                             C1ElementwiseOperation,
...
@@ -68,8 +68,7 @@ template <typename ALayout,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
-                                                               AElementwiseOperation,
+struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
                                                                BElementwiseOperation,
                                                                CElementwiseOperation,
                                                                DxsInElementwiseOperation,
@@ -695,7 +694,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
                         void* p_c,
-                        DPtrsGlobal p_dxs,
+                        void* p_dxs,
                         index_t MRaw,
                         index_t NRaw,
                         index_t KRaw,
@@ -709,10 +708,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                         DxsReduceAccElementwiseOperation dxs_out_element_op,
                         index_t /* KBatch */ = 1) override
     {
+        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          p_dxs,
+                                          dxs_tuple,
                                           MRaw,
                                           NRaw,
                                           KRaw,
...
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
 >;

 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances,
...
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
 >;

 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances,
...
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
 >;

 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances,
...
@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
 >;

 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances,
...
@@ -63,8 +63,7 @@ using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn
 >;

 void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
-    std::vector<DeviceGemmBiasAddReducePtr<DPtrsGlobal,
-                                           PassThrough,
+    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                            PassThrough,
                                            PassThrough,
                                            PassThrough,
...
@@ -63,8 +63,7 @@ using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk
 >;

 void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
-    std::vector<DeviceGemmBiasAddReducePtr<DPtrsGlobal,
-                                           PassThrough,
+    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                            PassThrough,
                                            PassThrough,
                                            PassThrough,
...
@@ -63,8 +63,7 @@ using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn
 >;

 void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
-    std::vector<DeviceGemmBiasAddReducePtr<DPtrsGlobal,
-                                           PassThrough,
+    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                            PassThrough,
                                            PassThrough,
                                            PassThrough,
...
@@ -60,8 +60,7 @@ using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk
 >;

 void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
-    std::vector<DeviceGemmBiasAddReducePtr<DPtrsGlobal,
-                                           PassThrough,
+    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                            PassThrough,
                                            PassThrough,
                                            PassThrough,
...
@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = s
 >;

 void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{});
...
@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = s
 >;

 void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{});
...
@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = s
 >;

 void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{});
...
@@ -59,12 +59,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = s
 >;

 void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{});
...
@@ -26,7 +26,6 @@ using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Identity, Identity>;

 using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
-    DPtrsGlobal,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
@@ -260,7 +259,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
             gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                           static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                           static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                          dxs_global,
+                                          &dxs_global,
                                           M,
                                           N,
                                           K,
...
@@ -27,7 +27,6 @@ using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Div, Div>;

 using DeviceGemmBiasAddReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmBiasAddReducePtr<
-    DPtrsGlobal,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
@@ -296,7 +295,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
                                           static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                           static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
                                           static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
-                                          dxs_global,
+                                          &dxs_global,
                                           M,
                                           N,
                                           K,
...
@@ -27,7 +27,6 @@ using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Div, Div>;

 using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
-    DPtrsGlobal,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
@@ -261,7 +260,7 @@ bool profile_gemm_reduce_impl(int do_verification,
             gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                           static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                           static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                          dxs_global,
+                                          &dxs_global,
                                           M,
                                           N,
                                           K,
...
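
The profiler hunks above show the caller-side half of the change: the reduction-pointer tuple is still built by the caller, and only its address crosses the type-erased interface. A self-contained, hedged sketch of that call pattern (std::tuple and the free function stand in for ck::Tuple and the CK member function; the names are hypothetical):

#include <memory>
#include <tuple>

// Hypothetical stand-in for DPtrsGlobal; in the profiler it is a ck::Tuple of
// reduction-output pointers obtained from DeviceMem buffers.
using DPtrsGlobal = std::tuple<float*, float*>;

std::unique_ptr<int> MakeArgumentPointer(const void*, const void*, void*, void* p_dxs)
{
    // Derived-class side: cast the opaque pointer back to the concrete tuple.
    DPtrsGlobal dxs = *(static_cast<DPtrsGlobal*>(p_dxs));
    (void)dxs;
    return std::make_unique<int>(0);
}

int main()
{
    float d0[16] = {};
    float d1[16] = {};

    // Caller side, mirroring the profiler change: build the tuple locally and
    // pass its address (was: pass the tuple by value).
    DPtrsGlobal dxs_global{d0, d1};
    auto arg = MakeArgumentPointer(nullptr, nullptr, nullptr, &dxs_global);
    (void)arg;
}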