Commit 5f94555b authored by Anthony Chang's avatar Anthony Chang
Browse files

implement proper interface

parent 98e4c0ce
......@@ -49,23 +49,26 @@ using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using B0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle<
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
B0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
1,
......@@ -114,10 +117,10 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
ADataType,
AccDataType,
AElementOp,
BElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, B1ElementOp, CElementOp>;
int main(int argc, char* argv[])
{
......@@ -130,7 +133,7 @@ int main(int argc, char* argv[])
// ck::index_t N = 1024;
// ck::index_t K = 64;
// ck::index_t O = 64;
// ck::index_t BatchCount = 4;
// ck::index_t StrideA = 1024;
// ck::index_t StrideB0 = 1024;
// ck::index_t StrideB1 = 1024;
......@@ -140,11 +143,15 @@ int main(int argc, char* argv[])
ck::index_t N = 128;
ck::index_t K = 32;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = 32;
ck::index_t StrideB0 = 32;
ck::index_t StrideB1 = 128;
ck::index_t StrideC = 128;
ck::index_t BatchCount = 64;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
if(argc == 1)
{
......@@ -156,7 +163,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 17)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -167,12 +174,17 @@ int main(int argc, char* argv[])
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
StrideA = std::stoi(argv[8]);
StrideB0 = std::stoi(argv[9]);
StrideB1 = std::stoi(argv[10]);
StrideC = std::stoi(argv[11]);
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchCount = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
}
else
{
......@@ -183,29 +195,55 @@ int main(int argc, char* argv[])
exit(0);
}
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Row> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Row> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Row> ? O : N) * StrideB1;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Row> ? O : M) * StrideC;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1}));
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride}));
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB0, B0Layout{}));
Tensor<B1DataType> b1_n_o(f_host_tensor_descriptor(BatchCount, N, O, StrideB1, B1Layout{}));
Tensor<CDataType> c_m_o_host_result(f_host_tensor_descriptor(BatchCount, M, O, StrideC, CLayout{}));
Tensor<CDataType> c_m_o_device_result(f_host_tensor_descriptor(BatchCount, M, O, StrideC, CLayout{}));
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
......@@ -241,7 +279,8 @@ int main(int argc, char* argv[])
b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto b0_element_op = B0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
......@@ -255,14 +294,19 @@ int main(int argc, char* argv[])
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b_element_op,
c_element_op,
BatchCount);
b0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -290,19 +334,19 @@ int main(int argc, char* argv[])
if(do_verification)
{
// Output of Gemm0 is input A of Gemm1
Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, Row{}));
Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_m_k, b0_k_n, a1_m_n, a_element_op, b_element_op, c_element_op);
a_m_k, b0_k_n, a1_m_n, a_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_m_n, b1_n_o, c_m_o_host_result, a_element_op, b_element_op, c_element_op);
a1_m_n, b1_n_o, c_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA,
ck::index_t StrideB0,
ck::index_t StrideB1,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB0,
ck::index_t BatchStrideB1,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/device_utility/device_prop.hpp"
......@@ -25,6 +25,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
......@@ -43,6 +44,7 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
......@@ -68,14 +70,6 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
// if(threadIdx.x == 0)
// printf("bid = %zd, offset a b c d = %zd, %zd, %zd, %zd\n",
// hipBlockIdx_x,
// a_batch_offset,
// b_batch_offset,
// b1_batch_offset,
// c_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
......@@ -83,6 +77,7 @@ __global__ void
p_shared,
a_element_op,
b_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -92,20 +87,21 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_shared;
ignore = a_element_op;
ignore = b_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
......@@ -114,12 +110,14 @@ template <typename ALayout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename BDataType, // NOTE: don't distinguish B0/B1 type just yet
typename BDataType,
typename B1DataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, // NOTE: don't distinguish B0/B1 type just yet
typename BElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
......@@ -163,9 +161,20 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit from DeviceGemmGemm subtype
struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout,
BLayout,
B1Layout,
CLayout,
ADataType,
BDataType,
B1DataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;
using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -526,6 +535,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
CDataType,
AElementwiseOperation,
BElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
......@@ -582,20 +592,25 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
const BDataType* p_b1_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw, // = ORaw
index_t Batch,
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t Batch)
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
......@@ -609,13 +624,10 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(b1_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())}
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -632,7 +644,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const BDataType* p_b1_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
......@@ -643,6 +655,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
......@@ -680,6 +693,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
CDataType,
AElementwiseOperation,
BElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
......@@ -700,6 +714,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......@@ -760,20 +775,25 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const BDataType* p_b1,
const B1DataType* p_b1,
CDataType* p_c,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t Batch)
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
......@@ -783,14 +803,19 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
NRaw,
KRaw,
Gemm1NRaw,
Batch,
StrideA,
StrideB,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideB1,
BatchStrideC,
a_element_op,
b_element_op,
c_element_op,
Batch};
b1_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -804,35 +829,45 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t Batch) /* override */
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const BDataType*>(p_b1),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
StrideA,
StrideB,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideB1,
BatchStrideC,
a_element_op,
b_element_op,
c_element_op,
Batch);
b1_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() /* override */
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
......@@ -843,7 +878,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmGemm_Xdl_CShuffle"
str << "DeviceBatchedGemmGemm_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......@@ -851,7 +886,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< NPerBlock << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ">";
......
......@@ -23,6 +23,7 @@ template <typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
......@@ -316,6 +317,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment