Commit 489599ba authored by Jing Zhang's avatar Jing Zhang Committed by root
Browse files

add multiD support into gridwise and deviceOp

parent ad1597c4
...@@ -25,7 +25,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -25,7 +25,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using DeviceGemmV2Instance = using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, ADataType, BDataType, ck::Tuple<>, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault, PassThrough, PassThrough, PassThrough, GemmDefault,
256, 256,
224, 256, 224, 256,
......
...@@ -133,10 +133,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -133,10 +133,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#ifdef BUILD_INT4_EXAMPLE #ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()), static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()), static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
{},
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()), static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else #else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif #endif
M, M,
...@@ -144,6 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -144,6 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
K, K,
StrideA, StrideA,
StrideB, StrideB,
{},
StrideC, StrideC,
KBatch, KBatch,
a_element_op, a_element_op,
......
...@@ -14,21 +14,26 @@ template <typename ALayout, ...@@ -14,21 +14,26 @@ template <typename ALayout,
typename CLayout, typename CLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename CDataType, typename CDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGemmV2 : public BaseOperator struct DeviceGemmV2 : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideC, ck::index_t StrideC,
ck::index_t KSplit, ck::index_t KSplit,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
......
...@@ -25,6 +25,7 @@ template <typename ALayout, ...@@ -25,6 +25,7 @@ template <typename ALayout,
typename CLayout, typename CLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename CDataType, typename CDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
...@@ -69,11 +70,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -69,11 +70,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation> CElementwiseOperation>
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout, ALayout,
...@@ -83,6 +87,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -83,6 +87,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BDataType, BDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
Tuple<>,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -586,19 +591,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -586,19 +591,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
std::array<const void*, NumDTensor> p_ds,
CDataType* p_c, CDataType* p_c,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch, index_t KBatch,
AElementwiseOperation, AElementwiseOperation a_element_op,
BElementwiseOperation, BElementwiseOperation b_element_op,
CElementwiseOperation) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; return Argument{p_a,
p_b,
p_ds,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -606,28 +627,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -606,28 +627,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch, index_t KBatch,
AElementwiseOperation, AElementwiseOperation a_element_op,
BElementwiseOperation, BElementwiseOperation b_element_op,
CElementwiseOperation) override CElementwiseOperation c_element_op) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideDs,
StrideC, StrideC,
KBatch); KBatch,
a_element_op,
b_element_op,
c_element_op);
} }
// polymorphic // polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment