"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "cffaa589571db4abc4f9e6886176b87da31092e8"
Commit 489599ba authored by Jing Zhang's avatar Jing Zhang Committed by root
Browse files

add multiD support into gridwise and deviceOp

parent ad1597c4
......@@ -25,7 +25,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
ADataType, BDataType, ck::Tuple<>, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
224, 256,
......
......@@ -133,10 +133,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
{},
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
......@@ -144,6 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
K,
StrideA,
StrideB,
{},
StrideC,
KBatch,
a_element_op,
......
......@@ -14,21 +14,26 @@ template <typename ALayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmV2 : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideC,
ck::index_t KSplit,
AElementwiseOperation a_element_op,
......
......@@ -25,6 +25,7 @@ template <typename ALayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
......@@ -69,11 +70,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
......@@ -83,6 +87,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BDataType,
GemmAccDataType,
CShuffleDataType,
Tuple<>,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
......@@ -586,19 +591,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
std::array<const void*, NumDTensor> p_ds,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
return Argument{p_a,
p_b,
p_ds,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -606,28 +627,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch);
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment