"vscode:/vscode.git/clone" did not exist on "adb6134398cb1623a2439de5cead128d1ebdba3f"
Unverified Commit cd167e49 authored by Chao Liu, committed by GitHub

Compile for gfx908 and gfx90a (#130)

* adding compilation for multiple targets

* fix build

* clean

* update Jenkinsfile

* update readme

* update Jenkins

* use ck::half_t instead of ushort for bf16

* rename enum classes

* clean

* rename

* clean
parent ecf337ba
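Most of the diff below implements the "rename enum classes" item from the commit message: the `_t` suffix is dropped from the scoped enums `InMemoryDataOperationEnum_t`, `ConvolutionForwardSpecialization_t`, `ConvolutionBackwardDataSpecialization_t`, and `GemmSpecialization_t`, and the non-type template parameter previously named `GemmSpecialization` becomes `GemmSpec` so it no longer collides with the type name. A minimal C++ sketch of the rename's shape, trimmed to identifiers that appear in this diff (`ExampleTransfer` is a hypothetical placeholder, not a CK class):

```cpp
// Before (old spelling, removed by this commit):
//   enum struct InMemoryDataOperationEnum_t { Set, AtomicAdd };
//   template <InMemoryDataOperationEnum_t DstInMemOp, ...> struct SomeTransfer;

// After (new spelling used throughout this diff):
enum struct InMemoryDataOperationEnum
{
    Set,       // plain store to the destination
    AtomicAdd, // atomically accumulate into the destination
};

// Hypothetical placeholder showing how the renamed enum is spelled at a use
// site such as the transfer / grid-wise GEMM templates in the hunks below.
template <InMemoryDataOperationEnum DstInMemOp>
struct ExampleTransfer
{
};

// e.g. ExampleTransfer<InMemoryDataOperationEnum::Set>
//      ExampleTransfer<InMemoryDataOperationEnum::AtomicAdd>
```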
@@ -15,7 +15,7 @@ namespace ck {
// 3. Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
-InMemoryDataOperationEnum_t DstInMemOp,
+InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
......
@@ -15,7 +15,7 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
-InMemoryDataOperationEnum_t DstInMemOp,
+InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
......
@@ -5,7 +5,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
-enum struct ConvolutionBackwardDataSpecialization_t
+enum struct ConvolutionBackwardDataSpecialization
{
Default,
Filter1x1Stride1Pad0,
......
@@ -7,7 +7,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
-enum struct ConvolutionForwardSpecialization_t
+enum struct ConvolutionForwardSpecialization
{
Default,
Filter1x1Pad0,
@@ -15,14 +15,14 @@ enum struct ConvolutionForwardSpecialization_t
OddC,
};
-inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization_t& s)
+inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s)
{
switch(s)
{
-case ConvolutionForwardSpecialization_t::Default: return "Default";
-case ConvolutionForwardSpecialization_t::Filter1x1Pad0: return "Filter1x1Pad0";
-case ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
-case ConvolutionForwardSpecialization_t::OddC: return "OddC";
+case ConvolutionForwardSpecialization::Default: return "Default";
+case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
+case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
+case ConvolutionForwardSpecialization::OddC: return "OddC";
default: return "Unrecognized specialization!";
}
}
......
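For reference, a small hypothetical usage sketch of the renamed helper above, assuming the CK header that defines `ConvolutionForwardSpecialization` and `getConvFwdSpecializationStr` in namespace `ck::tensor_operation::device` is included:

```cpp
#include <iostream>
// Assumes the header containing the enum and helper shown in the hunk above
// has already been included; the include path is not reproduced here.

int main()
{
    using ck::tensor_operation::device::ConvolutionForwardSpecialization;
    using ck::tensor_operation::device::getConvFwdSpecializationStr;

    constexpr auto spec = ConvolutionForwardSpecialization::Filter1x1Pad0;
    std::cout << getConvFwdSpecializationStr(spec) << std::endl; // prints "Filter1x1Pad0"
    return 0;
}
```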
@@ -105,7 +105,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(d_grid_desc_mraw,
@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
-InMemoryDataOperationEnum_t::Set,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::Set,
+InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
......
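All of the padding branches above follow the same pattern: each raw dimension is rounded up to a multiple of its block tile via `math::integer_divide_ceil`, and the pad is applied only when `GemmSpec` selects that dimension. A simplified, standalone sketch of that `if constexpr` dispatch for the M dimension (the function name `PadAmountM` and the `Default` enumerator are illustrative additions; the descriptor transforms themselves are omitted):

```cpp
// Enumerators as they appear in this diff, plus an assumed Default fallback.
enum struct GemmSpecialization
{
    Default,
    MPadding,
    NPadding,
    KPadding,
    MNPadding,
    MKPadding,
    NKPadding,
    MNKPadding,
};

// Mirrors the per-dimension selection in the descriptor builders above:
// pad M only when GemmSpec includes M padding, otherwise leave it unpadded.
template <GemmSpecialization GemmSpec, int MPerBlock>
constexpr int PadAmountM(int MRaw)
{
    if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                 GemmSpec == GemmSpecialization::MNPadding ||
                 GemmSpec == GemmSpecialization::MKPadding ||
                 GemmSpec == GemmSpecialization::MNKPadding)
    {
        // integer_divide_ceil(MRaw, MPerBlock) * MPerBlock
        const int M = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock;
        return M - MRaw;
    }
    else
    {
        return 0;
    }
}

static_assert(PadAmountM<GemmSpecialization::MNKPadding, 64>(100) == 28, "100 is padded up to 128");
static_assert(PadAmountM<GemmSpecialization::NPadding, 64>(100) == 0, "M is left unpadded");
```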
@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
@@ -209,7 +209,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -250,7 +250,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization,
+ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
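The `IsSupportedArgument` hunks above (and the similar ones in the forward-convolution files below) guard each compile-time specialization with a runtime check on the actual problem shape. A hedged sketch of the 1x1 / stride-1 / pad-0 test performed by the `Filter1x1Stride1Pad0` branch; only `filter_spatial_lengths_` appears verbatim in the diff, so the `ConvArgs` struct and its other field names are illustrative stand-ins rather than the exact CK `Argument` layout:

```cpp
#include <array>

// Illustrative stand-in for the parts of the kernel Argument used by the check.
struct ConvArgs
{
    std::array<int, 2> filter_spatial_lengths_; // {Y, X}
    std::array<int, 2> conv_filter_strides_;    // {StrideH, StrideW} (assumed name)
    std::array<int, 2> input_left_pads_;        // {PadH, PadW} (assumed name)
    std::array<int, 2> input_right_pads_;       // {PadH, PadW} (assumed name)
};

// Sketch of "check if it's 1x1, stride=1, pad=0 conv" for a 2-D convolution.
inline bool IsFilter1x1Stride1Pad0(const ConvArgs& arg)
{
    for(int i = 0; i < 2; ++i)
    {
        if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
             arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
        {
            return false;
        }
    }
    return true;
}
```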
@@ -27,7 +27,7 @@ template <
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -125,7 +125,7 @@ struct
const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
@@ -179,7 +179,7 @@ struct
resi_grid_desc_gemmm_gemmn);
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
@@ -249,7 +249,7 @@ struct
bias_grid_desc_gemmm_gemmn,
resi_grid_desc_gemmm_gemmn);
}
-else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
+else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value
const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
@@ -466,7 +466,7 @@ struct
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -811,7 +811,7 @@ struct
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -823,7 +823,7 @@ struct
}
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
@@ -27,8 +27,8 @@ template <
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+InMemoryDataOperationEnum OutGlobalMemoryDataOperation,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
bias_grid_desc_gemmm_gemmn);
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn);
}
-else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
+else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value
const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
}
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
@@ -26,7 +26,7 @@ template <
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
out_gemmm_gemmn_grid_desc);
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
-else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
+else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value
const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
AccDataType,
CDataType, // TODO: Add ShuffleType for DeviceConv2d
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
}
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -119,7 +119,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t GemmK0 = GemmK / GemmK1Number;
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// A: input tensor
const auto in_gemmmraw_gemmk_grid_desc =
@@ -159,7 +159,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
out_gemmm_gemmn_grid_desc);
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// A: input tensor
const auto in_n_hi_wi_c_grid_desc =
@@ -316,7 +316,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -565,7 +565,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -577,7 +577,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
@@ -83,7 +83,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
-static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Default,
+static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default,
"Wrong! This specialization not implemented!");
const auto in_desc_n_di_hi_wi_c =
@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
InDataType,
AccDataType,
OutDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization,
+ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t NumDimSpatial,
ck::index_t BlockSize,
ck::index_t MPerBlock,
@@ -116,7 +116,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
@@ -336,7 +336,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
@@ -618,7 +618,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
@@ -959,7 +959,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -1385,7 +1385,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NumDimSpatial; i++)
@@ -1527,7 +1527,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
<< K0PerBlock
<< ">";
if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0){
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){
str<< " Filter1x1Stride1Pad0";
}
......
@@ -44,7 +44,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t NumDimSpatial,
ck::index_t BlockSize,
ck::index_t MPerBlock,
@@ -142,7 +142,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
@@ -156,7 +156,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
@@ -262,7 +262,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
@@ -276,7 +276,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
@@ -395,7 +395,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
@@ -409,7 +409,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
@@ -613,7 +613,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -878,7 +878,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i)
@@ -891,7 +891,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i)
......
@@ -29,7 +29,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -95,8 +95,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
@@ -119,8 +119,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return a_grid_desc_ak0_m_ak1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
@@ -136,8 +136,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return a_grid_desc_ak0_m_ak1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
@@ -198,8 +198,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
@@ -222,8 +222,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return b_grid_desc_bk0_n_bk1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
@@ -239,8 +239,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return b_grid_desc_bk0_n_bk1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
@@ -301,8 +301,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
@@ -311,8 +311,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
@@ -321,8 +321,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
@@ -346,10 +346,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(d_grid_desc_mraw,
@@ -382,8 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
-InMemoryDataOperationEnum_t::Set,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::Set,
+InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
......
@@ -27,7 +27,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -80,7 +80,7 @@ struct DeviceGemmXdl
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
@@ -119,7 +119,7 @@ struct DeviceGemmXdl
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -154,7 +154,7 @@ struct DeviceGemmXdl
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -186,7 +186,7 @@ struct DeviceGemmXdl
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
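The MNPadding branch of DeviceGemmXdl above pads each dimension to the next multiple of its block tile with `(PerBlock - X % PerBlock) % PerBlock`. A quick self-contained check of that formula (illustrative only; `pad_to_multiple` is not a CK function):

```cpp
// Padding to the next multiple of MPerBlock, as computed in the MNPadding branch above.
constexpr int pad_to_multiple(int M, int MPerBlock)
{
    return (MPerBlock - M % MPerBlock) % MPerBlock;
}

static_assert(pad_to_multiple(100, 64) == 28, "100 is padded up to 128");
static_assert(pad_to_multiple(128, 64) == 0, "already a multiple of 64, so no padding");
```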
@@ -138,7 +138,7 @@ struct DeviceGemmXdl_C_Shuffle
AccDataType,
CShuffleDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
@@ -139,7 +139,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
@@ -147,7 +147,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......