Unverified Commit cd167e49, authored by Chao Liu, committed by GitHub

Compile for gfx908 and gfx90a (#130)

* adding compilation for multiple targets

* fix build

* clean

* update Jenkinsfile

* update readme

* update Jenkins

* use ck::half_t instead of ushort for bf16

* rename enum classes

* clean

* rename

* clean
parent ecf337ba
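The bulk of this diff is a mechanical rename: the `_t` suffix is dropped from the library's enum class names (`InMemoryDataOperationEnum_t`, `GemmSpecialization_t`, `ReduceTensorOp_t`, `AddressSpaceEnum_t`), and the `GemmSpecialization` template parameter becomes `GemmSpec` so it no longer collides with the now-unsuffixed type name. A minimal sketch of the pattern; the enumerator lists below are only the ones visible in the hunks that follow, not necessarily complete:

```cpp
namespace ck {

// before: enum struct InMemoryDataOperationEnum_t { Set, AtomicAdd, ... };
enum struct InMemoryDataOperationEnum
{
    Set,      // plain store to the destination
    AtomicAdd // atomic accumulate, used by the split-K / multiblock paths below
};

// before: enum struct ReduceTensorOp_t { ... };
enum struct ReduceTensorOp
{
    ADD, MUL, MIN, MAX, AMAX, AVG, NORM1, NORM2
};

} // namespace ck
```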
@@ -169,7 +169,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
@@ -24,7 +24,7 @@ template <typename ALayout,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          GemmSpecialization_t GemmSpecialization,
+          GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -84,8 +84,8 @@ struct DeviceGemm_Xdl_CShuffle
         const auto MPad = M - MRaw;
         const auto KPad = K - KRaw;
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
-                     GemmSpecialization == GemmSpecialization_t::MNKPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
         {
             // pad both M and K
             assert(K % AK1 == 0);
@@ -108,8 +108,8 @@ struct DeviceGemm_Xdl_CShuffle
             return a_grid_desc_ak0_m_ak1;
         }
-        else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-                          GemmSpecialization == GemmSpecialization_t::MNPadding)
+        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                          GemmSpec == GemmSpecialization::MNPadding)
         {
             // pad M, but not K
             assert(KRaw % AK1 == 0);
@@ -125,8 +125,8 @@ struct DeviceGemm_Xdl_CShuffle
             return a_grid_desc_ak0_m_ak1;
         }
-        else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-                          GemmSpecialization == GemmSpecialization_t::NKPadding)
+        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+                          GemmSpec == GemmSpecialization::NKPadding)
         {
             // pad K, but not M
             assert(K % AK1 == 0);
@@ -187,8 +187,8 @@ struct DeviceGemm_Xdl_CShuffle
         const auto NPad = N - NRaw;
         const auto KPad = K - KRaw;
-        if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
-                     GemmSpecialization == GemmSpecialization_t::MNKPadding)
+        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
         {
             // pad both N and K
             assert(K % BK1 == 0);
@@ -211,8 +211,8 @@ struct DeviceGemm_Xdl_CShuffle
             return b_grid_desc_bk0_n_bk1;
         }
-        else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-                          GemmSpecialization == GemmSpecialization_t::MNPadding)
+        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                          GemmSpec == GemmSpecialization::MNPadding)
        {
             // pad N, but not K
             assert(KRaw % BK1 == 0);
@@ -228,8 +228,8 @@ struct DeviceGemm_Xdl_CShuffle
             return b_grid_desc_bk0_n_bk1;
         }
-        else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-                          GemmSpecialization == GemmSpecialization_t::MKPadding)
+        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+                          GemmSpec == GemmSpecialization::MKPadding)
         {
             // pad K, but not N
             assert(K % BK1 == 0);
@@ -290,8 +290,8 @@ struct DeviceGemm_Xdl_CShuffle
         const auto MPad = M - MRaw;
         const auto NPad = N - NRaw;
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
-                     GemmSpecialization == GemmSpecialization_t::MNKPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
         {
             // pad M and N
             return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
@@ -300,8 +300,8 @@ struct DeviceGemm_Xdl_CShuffle
                                                make_tuple(Sequence<0>{}, Sequence<1>{}),
                                                make_tuple(Sequence<0>{}, Sequence<1>{}));
         }
-        else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-                          GemmSpecialization == GemmSpecialization_t::MKPadding)
+        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                          GemmSpec == GemmSpecialization::MKPadding)
         {
             // pad M, but not N
             return transform_tensor_descriptor(
@@ -310,8 +310,8 @@ struct DeviceGemm_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
         }
-        else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-                          GemmSpecialization == GemmSpecialization_t::NKPadding)
+        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                          GemmSpec == GemmSpecialization::NKPadding)
         {
             // pad N, but not M
             return transform_tensor_descriptor(
@@ -340,7 +340,7 @@ struct DeviceGemm_Xdl_CShuffle
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
         CGridDesc_M_N,
...
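Every descriptor builder in this file picks its padding scheme at compile time by comparing the `GemmSpec` parameter against `GemmSpecialization` enumerators, as in the hunks above. A self-contained sketch of that `if constexpr` dispatch (simplified stand-ins, not the CK descriptor code):

```cpp
#include <cstdio>

enum struct GemmSpecialization
{
    Default, MPadding, NPadding, KPadding,
    MNPadding, MKPadding, NKPadding, MNKPadding
};

// Mirrors the branch structure of the A-descriptor construction above: which
// of M and K get padded is decided entirely at compile time from GemmSpec.
template <GemmSpecialization GemmSpec>
constexpr const char* APaddingScheme()
{
    if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                 GemmSpec == GemmSpecialization::MNKPadding)
        return "pad both M and K";
    else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                      GemmSpec == GemmSpecialization::MNPadding)
        return "pad M, but not K";
    else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                      GemmSpec == GemmSpecialization::NKPadding)
        return "pad K, but not M";
    else
        return "no padding on M or K";
}

int main()
{
    std::printf("%s\n", APaddingScheme<GemmSpecialization::MKPadding>());
    std::printf("%s\n", APaddingScheme<GemmSpecialization::NPadding>());
}
```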
@@ -31,7 +31,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          GemmSpecialization_t GemmSpecialization,
+          GemmSpecialization GemmSpec,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -91,7 +91,7 @@ struct DeviceGemmXdlSplitK
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
@@ -136,7 +136,7 @@ struct DeviceGemmXdlSplitK
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
@@ -170,7 +170,7 @@ struct DeviceGemmXdlSplitK
             }
         }();
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -209,7 +209,7 @@ struct DeviceGemmXdlSplitK
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
@@ -250,7 +250,7 @@ struct DeviceGemmXdlSplitK
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::AtomicAdd,
+        InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
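Each MNPadding branch above computes its pad amount with the same expression: round the dimension up to the next multiple of the block tile, yielding zero when it already divides evenly. A two-assertion check of the formula:

```cpp
// (tile - x % tile) % tile, as used for PadM / PadN above.
constexpr int PadAmount(int x, int tile) { return (tile - x % tile) % tile; }

static_assert(PadAmount(1000, 256) == 24, "1000 + 24 = 1024, a multiple of 256");
static_assert(PadAmount(1024, 256) == 0, "already a multiple; the trailing % avoids a full-tile pad");
```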
@@ -31,7 +31,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          GemmSpecialization_t GemmSpecialization,
+          GemmSpecialization GemmSpec,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -93,7 +93,7 @@ struct DeviceGemmXdlSplitKCShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
@@ -138,7 +138,7 @@ struct DeviceGemmXdlSplitKCShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
@@ -172,7 +172,7 @@ struct DeviceGemmXdlSplitKCShuffle
             }
         }();
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -211,7 +211,7 @@ struct DeviceGemmXdlSplitKCShuffle
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
@@ -253,7 +253,7 @@ struct DeviceGemmXdlSplitKCShuffle
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::AtomicAdd,
+        InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
@@ -27,7 +27,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          GemmSpecialization_t GemmSpecialization,
+          GemmSpecialization GemmSpec,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -81,7 +81,7 @@ struct DeviceGroupedGemmXdl
             }
         }();
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
@@ -120,7 +120,7 @@ struct DeviceGroupedGemmXdl
             }
         }();
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -155,7 +155,7 @@ struct DeviceGroupedGemmXdl
             }
         }();
-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -187,7 +187,7 @@ struct DeviceGroupedGemmXdl
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
@@ -10,7 +10,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-template <ck::ReduceTensorOp_t ReduceOpId>
+template <ck::ReduceTensorOp ReduceOpId>
 struct DevicePool2dFwd : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
@@ -29,7 +29,7 @@ struct DevicePool2dFwd : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
-template <ck::ReduceTensorOp_t ReduceOpId>
+template <ck::ReduceTensorOp ReduceOpId>
 using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
 } // namespace device
...
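After the rename, pooling operator instances are keyed by a `ck::ReduceTensorOp` value. A compilable sketch of holding a concrete pooling op through the alias above; the `BaseOperator` machinery is stubbed out, and `MaxPool2d` is a hypothetical concrete type (the real one is `DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C` in the next file):

```cpp
#include <memory>

// Stand-ins so this sketch compiles on its own.
enum struct ReduceTensorOp { ADD, MUL, MIN, MAX, AMAX, AVG, NORM1, NORM2 };

template <ReduceTensorOp ReduceOpId>
struct DevicePool2dFwd
{
    virtual ~DevicePool2dFwd() = default; // stands in for the BaseOperator interface
};

template <ReduceTensorOp ReduceOpId>
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;

// Hypothetical concrete instance for max pooling.
struct MaxPool2d final : DevicePool2dFwd<ReduceTensorOp::MAX>
{
};

int main()
{
    DevicePool2dFwdPtr<ReduceTensorOp::MAX> op = std::make_unique<MaxPool2d>();
    (void)op;
}
```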
@@ -16,7 +16,7 @@ namespace device {
 template <typename InDataType,
           typename OutDataType,
           typename AccDataType,
-          ck::ReduceTensorOp_t ReduceOpId,
+          ck::ReduceTensorOp ReduceOpId,
           bool NeedIndices,
           ck::index_t BlockSize,
           ck::index_t ReduceMThreadClusterSize,
@@ -181,7 +181,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
         reduce_lowest_length_ = window_spatial_lengths[1];
         // TODO: is this correct?
-        if constexpr(ReduceOpId == ck::ReduceTensorOp_t::AVG)
+        if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
         {
             ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
             in_element_op_ = InElementwiseOperation{divider};
...
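The AVG branch above implements average pooling as an ADD reduction whose elementwise operations carry a divider equal to the window area. A scalar sketch of that composition; exactly which functor applies the division is not visible in this hunk:

```cpp
#include <cstdio>

int main()
{
    // window_spatial_lengths = {2, 2}, as in the divider computation above
    int window_spatial_lengths[2] = {2, 2};
    int divider = window_spatial_lengths[0] * window_spatial_lengths[1]; // 4

    float window[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float acc = 0.0f;
    for(float v : window)
        acc += v;                   // the reduction itself is an ADD
    float avg = acc / divider;      // divider applied by an elementwise op
    std::printf("avg = %f\n", avg); // 2.5
}
```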
@@ -5,7 +5,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-enum struct GemmSpecialization_t
+enum struct GemmSpecialization
 {
     Default,
     MPadding,
...
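The hunk above truncates after `MPadding`. Judging from the specializations referenced throughout this diff, the renamed enum presumably reads in full (enumerator order is a guess):

```cpp
enum struct GemmSpecialization
{
    Default,
    MPadding,
    NPadding,
    KPadding,
    MNPadding,
    MKPadding,
    NKPadding,
    MNKPadding
};
```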
@@ -37,11 +37,11 @@ namespace ck {
 // The boolean member "indexable" is also provided in reduce_binary_operator for
 // easier checking by the upper-layer codes in the kernels.
-template <typename T, ReduceTensorOp_t Op>
+template <typename T, ReduceTensorOp Op>
 struct reduce_binary_operator;
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
+struct reduce_binary_operator<T, ReduceTensorOp::ADD>
 {
     using opType = reduce::Add<T>;
     using dataType = T;
@@ -50,7 +50,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
 };
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
+struct reduce_binary_operator<T, ReduceTensorOp::MUL>
 {
     using opType = reduce::Mul<T>;
     using dataType = T;
@@ -59,7 +59,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
 };
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
+struct reduce_binary_operator<T, ReduceTensorOp::MIN>
 {
     using opType = reduce::Min<T>;
     using dataType = T;
@@ -68,7 +68,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
 };
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
+struct reduce_binary_operator<T, ReduceTensorOp::MAX>
 {
     using opType = reduce::Max<T>;
     using dataType = T;
@@ -77,7 +77,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
 };
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
+struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
 {
     using opType = reduce::AMax<T>;
     using dataType = T;
@@ -86,7 +86,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
 };
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
+struct reduce_binary_operator<T, ReduceTensorOp::AVG>
 {
     using opType = reduce::Add<T>;
     using dataType = T;
@@ -95,7 +95,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
 };
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
+struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
 {
     using opType = reduce::Add<T>;
     using dataType = T;
@@ -104,7 +104,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
 };
 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
+struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
 {
     using opType = reduce::Add<T>;
     using dataType = T;
@@ -115,7 +115,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
 // The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
 // functor classes.
 // The two unary functors are called before and after the Reduction is executed respectively
-template <typename T, ReduceTensorOp_t Op, bool IsFirstReduce, bool IsLastReduce>
+template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
 struct reduce_unary_operator
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
@@ -123,42 +123,42 @@ struct reduce_unary_operator
 };
 template <typename T, bool IsFirstReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, IsFirstReduce, true>
+struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
 };
 template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, IsLastReduce>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
 };
 template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, IsLastReduce>
+struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
 };
 template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
 };
 template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
 };
 template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
...
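The NORM2 specializations above compose the 2-norm from a UnarySquare pre-op, an ADD reduction, and a UnarySqrt post-op applied only on the last reduction stage. A plain scalar sketch of that composition:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    float x[2] = {3.0f, 4.0f};
    float acc = 0.0f;                   // ReduceOperation is reduce::Add
    for(float v : x)
        acc += v * v;                   // InElementwiseOperation: UnarySquare
    float norm2 = std::sqrt(acc);       // AccElementwiseOperation: UnarySqrt
    std::printf("norm2 = %f\n", norm2); // 5.0
}
```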
@@ -227,21 +227,18 @@ struct GridwiseReduction_mk_to_m_blockwise
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
-        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
         auto block_reduce_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -336,7 +333,7 @@ struct GridwiseReduction_mk_to_m_blockwise
         {
             if(!float_equal_zero{}(beta))
             {
-                StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                     priorDstValueBuf;
                 auto threadwise_dst_load =
@@ -376,7 +373,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                     Sequence<0>,
                     0,
                     OutDstVectorSize,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     true>(
                     out_grid_desc_m,
@@ -422,30 +419,26 @@ struct GridwiseReduction_mk_to_m_blockwise
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
-        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
-        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_indices_global, out_grid_desc_m.GetElementSpaceSize());
         auto block_reduce_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
         auto block_reduce_idx_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_val_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_idx_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
-            accu_index_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
         const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
@@ -561,7 +554,7 @@ struct GridwiseReduction_mk_to_m_blockwise
         {
             if(!float_equal_zero{}(beta))
             {
-                StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                     priorDstValueBuf;
                 auto threadwise_dst_load =
@@ -601,7 +594,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                     Sequence<0>,
                     0,
                     OutDstVectorSize,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     false>(
                     out_grid_desc_m,
@@ -619,7 +612,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                     Sequence<0>,
                     0,
                     OutDstVectorSize,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     false>(
                     out_grid_desc_m,
@@ -678,36 +671,32 @@ struct GridwiseReduction_mk_to_m_blockwise
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
         const auto src_global_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_ws_values_global,
+            make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
                                                             in_grid_desc_m_k.GetElementSpaceSize(),
                                                             type_convert<InDataType>(zeroVal));
-        const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
-        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
-        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_indices_global, out_grid_desc_m.GetElementSpaceSize());
         auto block_reduce_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
         auto block_reduce_idx_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_val_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
+        StaticBuffer<AddressSpaceEnum::Vgpr,
                      IndexDataType,
                      MThreadSliceSize * KThreadSliceSize,
                      true>
             in_thread_idx_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
-            accu_index_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
         const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
@@ -835,7 +824,7 @@ struct GridwiseReduction_mk_to_m_blockwise
         {
             if(!float_equal_zero{}(beta))
             {
-                StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                     priorDstValueBuf;
                 auto threadwise_dst_load =
@@ -875,7 +864,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                    Sequence<0>,
                     0,
                     OutDstVectorSize,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     true>(
                     out_grid_desc_m,
@@ -893,7 +882,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                     Sequence<0>,
                     0,
                     OutDstVectorSize,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     true>(
                     out_grid_desc_m,
...
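Throughout this kernel the final write-back is guarded by `float_equal_zero{}(beta)`: prior destination contents are loaded into `priorDstValueBuf` and blended only when beta is non-zero. A scalar sketch of that guard; the exact blend formula is not visible in these hunks, so the `beta * prior` term below is the conventional BLAS-style one (an assumption):

```cpp
#include <cstdio>

int main()
{
    float beta = 0.5f, reduced = 8.0f, prior = 2.0f;

    float out = reduced;
    if(beta != 0.0f)         // float_equal_zero{}(beta) check in the kernel
        out += beta * prior; // conventional blend with the prior output (assumption)
    std::printf("out = %f\n", out); // 9.0
}
```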
@@ -140,21 +140,18 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
         // LDS
         __shared__ AccDataType p_block_reduce_buffer[BlockSize];
-        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
         auto block_reduce_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -259,7 +256,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
                     Sequence<0>,
                     0,
                     OutDstVectorSize,
-                    InMemoryDataOperationEnum_t::AtomicAdd,
+                    InMemoryDataOperationEnum::AtomicAdd,
                     1,
                     true>(
                     out_grid_desc_m,
...
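The multiblock variant above writes with `InMemoryDataOperationEnum::AtomicAdd` rather than `Set` because several thread blocks each reduce a different K-slice of the same output element and must merge their partial sums in global memory. A host-side sketch of that merge; `std::atomic` stands in for the GPU's global atomic add (`fetch_add` on `atomic<float>` needs C++20):

```cpp
#include <atomic>
#include <cstdio>

int main()
{
    // Two partial reductions of the same output element, e.g. from two
    // thread blocks that each owned one K-slice.
    float partial_block0 = 1.5f;
    float partial_block1 = 2.5f;

    std::atomic<float> out{0.0f};
    out.fetch_add(partial_block0); // block 0 commits its slice
    out.fetch_add(partial_block1); // block 1 commits its slice

    std::printf("out = %f\n", out.load()); // 4.0 regardless of commit order
}
```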
@@ -163,22 +163,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
         __shared__ AccDataType p_block_reduce_buffer[BlockSize];
         const auto in_global_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
+            make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
                                                             in_grid_desc_m_k.GetElementSpaceSize(),
                                                             type_convert<InDataType>(zeroVal));
-        auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
         auto block_reduce_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -272,7 +269,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                     Sequence<0, 1>,
                     1,
                     1,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     true>(
                     workspace_desc_m_k,
@@ -322,33 +319,29 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
         __shared__ index_t p_block_reduce_idx_buffer[BlockSize];
         const auto in_global_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
+            make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
                                                             in_grid_desc_m_k.GetElementSpaceSize(),
                                                             type_convert<InDataType>(zeroVal));
-        auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
-        auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
         auto block_reduce_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
         auto block_reduce_idx_buf =
-            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_val_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
+        StaticBuffer<AddressSpaceEnum::Vgpr,
                      IndexDataType,
                      MThreadSliceSize * KThreadSliceSize,
                      true>
             in_thread_idx_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
-            accu_index_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -461,7 +454,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                     Sequence<0, 1>,
                     1,
                     1,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     true>(
                     workspace_desc_m_k,
@@ -480,7 +473,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                     Sequence<0, 1>,
                     1,
                     1,
-                    InMemoryDataOperationEnum_t::Set,
+                    InMemoryDataOperationEnum::Set,
                     1,
                     true>(
                     workspace_desc_m_k,
...
@@ -132,18 +132,15 @@ struct GridwiseReduction_mk_to_m_threadwise
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
-        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -223,7 +220,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                     true>(
                 out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
-            StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                 priorDstValue_buf;
             threadwise_dst_load.Run(out_grid_desc_m,
@@ -248,7 +245,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                 Sequence<0>,
                 0,
                 OutDstVectorSize,
-                InMemoryDataOperationEnum_t::Set,
+                InMemoryDataOperationEnum::Set,
                 1,
                 false>(
                 out_grid_desc_m,
@@ -277,22 +274,18 @@ struct GridwiseReduction_mk_to_m_threadwise
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
-        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
-        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_indices_global, out_grid_desc_m.GetElementSpaceSize());
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
-            accu_index_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             accu_value_buf(I) = zeroVal;
@@ -382,7 +375,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                     false>(
                 out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
-            StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                 priorDstValue_buf;
             threadwise_dst_load.Run(out_grid_desc_m,
@@ -407,7 +400,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                 Sequence<0>,
                 0,
                 OutDstVectorSize,
-                InMemoryDataOperationEnum_t::Set,
+                InMemoryDataOperationEnum::Set,
                 1,
                 false>(
                 out_grid_desc_m,
@@ -424,7 +417,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                 Sequence<0>,
                 0,
                 OutDstVectorSize,
-                InMemoryDataOperationEnum_t::Set,
+                InMemoryDataOperationEnum::Set,
                 1,
                 false>(
                 out_grid_desc_m,
...
@@ -55,7 +55,7 @@ template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_GK0_GM0_GM1_GK1,
           typename BGridDesc_GK0_GN0_GN1_GK1,
           typename CGridDesc_GM0_GM1_GN0_GN1,
@@ -329,11 +329,11 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
                                   integral_constant<bool, HasMainKBlockLoop>,
                                   integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
-        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
-        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
-        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
         const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
@@ -383,7 +383,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
         // A matrix blockwise copy
         auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
             BlockSize,
-            InMemoryDataOperationEnum_t::Set,
+            InMemoryDataOperationEnum::Set,
             Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
             ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
             ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
@@ -407,7 +407,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
         // B matrix blockwise copy
         auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
             BlockSize,
-            InMemoryDataOperationEnum_t::Set,
+            InMemoryDataOperationEnum::Set,
             Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
             BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
             BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
@@ -467,7 +467,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
         // register allocation for output
-        auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
+        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
             c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
         ThreadwiseTensorSliceSet_v1<FloatAcc,
@@ -481,15 +481,15 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
         constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
-        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
-        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
-        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double + a_block_aligned_space_size,
             a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
-        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_b_block_double + b_block_aligned_space_size,
             b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
...
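The even/odd LDS buffers above (`a_block_even_buf` / `a_block_odd_buf`, carved out of one doubled allocation) implement ping-pong double buffering: while the math pipeline consumes one half, the blockwise copy refills the other, swapping roles each K-block step (hence the `HasDoubleTailKBlockLoop` flag). A skeletal sketch of the alternation, with no actual compute:

```cpp
#include <cstdio>

int main()
{
    constexpr int kBlockElems = 4;
    float lds_double[2 * kBlockElems] = {}; // like p_a_block_double: two halves
    float* even_buf = lds_double;
    float* odd_buf  = lds_double + kBlockElems;

    for(int k_block = 0; k_block < 6; ++k_block)
    {
        float* compute_buf  = (k_block % 2 == 0) ? even_buf : odd_buf;
        float* prefetch_buf = (k_block % 2 == 0) ? odd_buf : even_buf;
        // the GEMM stage would consume compute_buf here while the copy
        // stage refills prefetch_buf for the next iteration
        (void)compute_buf;
        (void)prefetch_buf;
    }
    std::printf("ran 6 ping-pong steps\n");
}
```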
@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
-InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
@@ -268,11 +268,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
-const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
-const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
-auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
@@ -315,7 +315,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
@@ -341,7 +341,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
@@ -403,7 +403,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
-auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
+auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
@@ -428,15 +428,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};
-auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
-auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
-auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
-auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());
...
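The even/odd LDS buffer pairs above implement double buffering: while the math consumes one half, the next blockwise copy fills the other. A host-side sketch of the same pointer arithmetic (names and sizes hypothetical; the real code places this in LDS and pads sizes to the LDS alignment):

#include <cstddef>
#include <vector>

// Sketch: carve one shared allocation into even/odd halves for A and B, the
// same layout as p_a_block_double / p_b_block_double above. a_size/b_size
// stand in for the descriptors' GetElementSpaceSize() values, assumed
// already rounded up to the alignment.
template <typename T>
struct DoubleBuffer
{
    T* even;
    T* odd;
};

template <typename T>
DoubleBuffer<T> make_double_buffer(T* base, std::size_t aligned_size)
{
    return {base, base + aligned_size};
}

int main()
{
    constexpr std::size_t a_size = 256, b_size = 512;
    std::vector<float> shared(2 * a_size + 2 * b_size); // stand-in for LDS

    auto a = make_double_buffer(shared.data(), a_size);
    auto b = make_double_buffer(shared.data() + 2 * a_size, b_size);

    // The main K loop ping-pongs between the halves: iteration k computes on
    // (k % 2 ? odd : even) while the copy fills the other half.
    (void)a; (void)b;
}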
@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
-InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
@@ -275,11 +275,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
-const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
-const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
-auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
// divide block work by [M, N]
@@ -325,7 +325,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
@@ -349,7 +349,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
@@ -409,7 +409,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
-auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
+auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
@@ -423,15 +423,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
-auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
-auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
-auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
-auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
...
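v1r3 differs from v1r2 mainly in splitting the reduction dimension K into (K0, K1), as the `_k0_..._k1` descriptor names and the `K1.value` entry of the copy slice suggest; K1 would be the contiguous inner length sized for vectorized memory access. A toy flattening under that assumption:

#include <cassert>
#include <cstddef>

// Sketch of the K = K0 * K1 split assumed by the _k0..._k1 descriptors:
// K1 is a fixed, vector-friendly inner length; K0 counts the outer steps.
constexpr std::size_t K1 = 4; // hypothetical, e.g. 4 fp16 elements per load

constexpr std::size_t linear_k(std::size_t k0, std::size_t k1)
{
    return k0 * K1 + k1; // flatten (k0, k1) back to the original K index
}

int main()
{
    constexpr std::size_t K = 64;
    static_assert(K % K1 == 0, "K must be divisible by K1");
    assert(linear_k(K / K1 - 1, K1 - 1) == K - 1);
}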
@@ -15,7 +15,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
-InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
@@ -84,11 +84,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
-const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
-const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
-auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
constexpr auto E = EPerBlock * 3 * 3;
@@ -181,7 +181,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
@@ -221,11 +221,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
-auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
@@ -250,7 +250,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BGlobalMoveSliceWindowStepHacks{};
// double regsiter buffer for b
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
...
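`constexpr auto E = EPerBlock * 3 * 3;` in the v3 hunk reads as an implicit-GEMM convolution with a hard-coded 3x3 filter window folded into the reduction dimension. A worked size check under that reading (the EPerBlock value is hypothetical):

#include <cstddef>

// Sketch: for an implicit-GEMM convolution, the reduction length per block
// folds the filter window Y*X into the channel tile, matching
// "constexpr auto E = EPerBlock * 3 * 3;" above for a 3x3 filter.
constexpr std::size_t Y = 3, X = 3;  // filter height/width (from the hunk)
constexpr std::size_t EPerBlock = 8; // hypothetical channel tile

constexpr std::size_t E = EPerBlock * Y * X; // reduction steps per block

static_assert(E == 72, "8 channels x 3 x 3 window = 72 reduction steps");

int main() { return 0; }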
@@ -20,7 +20,7 @@ template <typename GridwiseGemm,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
-ActivTypeEnum_t ActivType>
+ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -50,7 +50,7 @@ __global__ void
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
-integral_constant<ActivTypeEnum_t, ActivType>{});
+integral_constant<ActivTypeEnum, ActivType>{});
}
template <typename GridwiseGemm,
@@ -62,7 +62,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
-ActivTypeEnum_t ActivType>
+ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -94,7 +94,7 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
-integral_constant<ActivTypeEnum_t, ActivType>{});
+integral_constant<ActivTypeEnum, ActivType>{});
}
template <typename GridwiseGemm,
@@ -106,7 +106,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
-ActivTypeEnum_t ActivType>
+ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -140,14 +140,14 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
-integral_constant<ActivTypeEnum_t, ActivType>{});
+integral_constant<ActivTypeEnum, ActivType>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
-InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
@@ -559,7 +559,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto bias_k0_k1_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
bias_k0_k1_thread_desc.GetElementSpaceSize(),
true>
@@ -602,10 +602,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
});
}
-template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum_t activ_type_>
+template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum activ_type_>
__device__ static void Activation(CThreadBuff& c_thread_buf,
const CThreadDesc_K1_N_H2_W2&,
-integral_constant<ActivTypeEnum_t, activ_type_>)
+integral_constant<ActivTypeEnum, activ_type_>)
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
@@ -737,7 +737,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1,
Number<WoPerThread_2>{}));
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
@@ -783,7 +783,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
@@ -843,7 +843,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1,
Number<WoPerThreadx2>{}));
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
@@ -874,7 +874,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
-InMemoryDataOperationEnum_t::Add,
+InMemoryDataOperationEnum::Add,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
@@ -964,7 +964,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<I1, E1, I1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
@@ -1023,11 +1023,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
0,
0));
-auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());
//// register allocation for output
-// StaticBuffer<AddressSpaceEnum_t::Vgpr,
+// StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAcc,
// c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
// true>
@@ -1050,7 +1050,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};
// double regsiter buffer for b
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
true>
@@ -1294,21 +1294,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
-const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
-auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
-auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
@@ -1344,7 +1344,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
-ActivTypeEnum_t ActivType>
+ActivTypeEnum ActivType>
__device__ static void ConvBiasActiv(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
@@ -1356,26 +1356,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
-integral_constant<ActivTypeEnum_t, ActivType>)
+integral_constant<ActivTypeEnum, ActivType>)
{
-static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
+static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
-const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
-auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
@@ -1423,7 +1423,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
-ActivTypeEnum_t ActivType>
+ActivTypeEnum ActivType>
__device__ static void ConvBiasActivMaxpool(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
@@ -1437,28 +1437,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
-integral_constant<ActivTypeEnum_t, ActivType>)
+integral_constant<ActivTypeEnum, ActivType>)
{
-static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
+static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
-const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
-auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
-auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
@@ -1514,7 +1514,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
-ActivTypeEnum_t ActivType>
+ActivTypeEnum ActivType>
__device__ static void ConvBiasActivResizeAdd(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
@@ -1527,26 +1527,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
-integral_constant<ActivTypeEnum_t, ActivType>)
+integral_constant<ActivTypeEnum, ActivType>)
{
-static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
+static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
-const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
-auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
...
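Throughout the kernels above, the activation variant travels as a non-type template parameter and is passed by value as an `integral_constant<ActivTypeEnum, ActivType>`, so dispatch happens at compile time and each instantiation contains exactly one activation path. A minimal sketch of the idiom (enumerator names hypothetical; the diff shows only the type name):

#include <algorithm>
#include <type_traits>

// Hypothetical enumerators for illustration only.
enum class ActivTypeEnum { None, Relu };

template <ActivTypeEnum A>
void activation(float* buf, int n, std::integral_constant<ActivTypeEnum, A>)
{
    // "if constexpr" discards the untaken branch, so each kernel
    // instantiation carries exactly one activation, with no runtime switch.
    if constexpr (A == ActivTypeEnum::Relu)
    {
        for (int i = 0; i < n; ++i)
            buf[i] = std::max(buf[i], 0.0f);
    }
}

int main()
{
    float v[4] = {-1.0f, 2.0f, -3.0f, 4.0f};
    activation(v, 4, std::integral_constant<ActivTypeEnum, ActivTypeEnum::Relu>{});
}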
@@ -79,8 +79,8 @@ template <typename FloatAB,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
-InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
-InMemoryDataOperationEnum_t DGlobalMemoryDataOperation,
+InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
@@ -363,15 +363,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
-const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
-auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
// divide block work by [M, N]
@@ -399,7 +399,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
@@ -430,7 +430,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
@@ -484,10 +484,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
@@ -563,7 +563,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -632,7 +632,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
@@ -723,13 +723,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
-auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>(
+auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
-auto d0_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>(
+auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
-auto d1_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>(
+auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
...
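Two in-memory data operations appear in this commit: `InMemoryDataOperationEnum::Set` overwrites the destination, while `InMemoryDataOperationEnum::Add` accumulates into it (used by the ConvBiasActivResizeAdd store earlier). A minimal sketch of the distinction as a compile-time store policy, assuming those semantics from the names:

enum class InMemoryDataOperationEnum { Set, Add };

// Sketch: the destination update is chosen at compile time, the way the
// threadwise transfers above are instantiated with Set or Add.
template <InMemoryDataOperationEnum Op, typename T>
void update(T& dst, T src)
{
    if constexpr (Op == InMemoryDataOperationEnum::Set)
        dst = src;   // plain store
    else
        dst += src;  // read-modify-write accumulate (e.g. resize-add epilogue)
}

int main()
{
    float d = 1.0f;
    update<InMemoryDataOperationEnum::Add>(d, 2.0f); // d == 3.0f
    update<InMemoryDataOperationEnum::Set>(d, 5.0f); // d == 5.0f
}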
@@ -60,7 +60,7 @@ template <typename FloatAB,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
-InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
@@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
-const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
@@ -348,7 +348,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
@@ -379,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
@@ -433,10 +433,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
@@ -512,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -581,7 +581,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
...
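Both xdl_cshuffle kernels carve a single `p_shared` LDS allocation: the A tile starts at offset 0 and the B tile starts at A's element count rounded up to `max_lds_align` by `math::integer_least_multiple`. A standalone sketch of that rounding (values hypothetical):

#include <cstddef>

// Sketch of math::integer_least_multiple as used above: round x up to the
// nearest multiple of align, so B's LDS tile starts on an aligned boundary.
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    constexpr std::size_t max_lds_align = 8;  // hypothetical value
    constexpr std::size_t a_elems    = 1030;  // A tile element count
    constexpr std::size_t a_aligned  = integer_least_multiple(a_elems, max_lds_align);

    static_assert(a_aligned == 1032, "1030 rounded up to a multiple of 8");

    float shared[2048];                       // stand-in for the LDS block
    float* p_b_block = shared + a_aligned;    // B starts after the padded A tile
    (void)p_b_block;
}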