Commit 5f94555b authored by Anthony Chang

implement proper interface

parent 98e4c0ce
@@ -49,23 +49,26 @@ using B1Layout = Row;
 using CLayout = Row;

 using AElementOp = PassThrough;
-using BElementOp = PassThrough;
+using B0ElementOp = PassThrough;
+using B1ElementOp = PassThrough;
 using CElementOp = PassThrough;

 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle<
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
     ALayout,
     B0Layout,
     B1Layout,
     CLayout,
     ADataType,
     B0DataType,
+    B1DataType,
     CDataType,
     AccDataType,
     CShuffleDataType,
     AElementOp,
-    BElementOp,
+    B0ElementOp,
+    B1ElementOp,
     CElementOp,
     GemmDefault,
     1,
@@ -114,10 +117,10 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
     ADataType,
     AccDataType,
     AElementOp,
-    BElementOp,
+    B0ElementOp,
     CElementOp>;

 using ReferenceGemm1Instance = ck::tensor_operation::host::
-    ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+    ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, B1ElementOp, CElementOp>;

 int main(int argc, char* argv[])
 {
@@ -130,7 +133,7 @@ int main(int argc, char* argv[])
     // ck::index_t N = 1024;
     // ck::index_t K = 64;
     // ck::index_t O = 64;
+    // ck::index_t BatchCount = 4;
     // ck::index_t StrideA = 1024;
     // ck::index_t StrideB0 = 1024;
     // ck::index_t StrideB1 = 1024;
@@ -140,11 +143,15 @@ int main(int argc, char* argv[])
     ck::index_t N = 128;
     ck::index_t K = 32;
     ck::index_t O = 128;
+    ck::index_t BatchCount = 4;
     ck::index_t StrideA = 32;
     ck::index_t StrideB0 = 32;
     ck::index_t StrideB1 = 128;
     ck::index_t StrideC = 128;
-    ck::index_t BatchCount = 64;
+    ck::index_t BatchStrideA = -1;
+    ck::index_t BatchStrideB0 = -1;
+    ck::index_t BatchStrideB1 = -1;
+    ck::index_t BatchStrideC = -1;

     if(argc == 1)
     {
@@ -156,7 +163,7 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 17)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
@@ -167,12 +174,17 @@ int main(int argc, char* argv[])
         K = std::stoi(argv[6]);
         O = std::stoi(argv[7]);
-        StrideA = std::stoi(argv[8]);
-        StrideB0 = std::stoi(argv[9]);
-        StrideB1 = std::stoi(argv[10]);
-        StrideC = std::stoi(argv[11]);
-        BatchCount = std::stoi(argv[12]);
+        BatchCount = std::stoi(argv[8]);
+        StrideA = std::stoi(argv[9]);
+        StrideB0 = std::stoi(argv[10]);
+        StrideB1 = std::stoi(argv[11]);
+        StrideC = std::stoi(argv[12]);
+        BatchStrideA = std::stoi(argv[13]);
+        BatchStrideB0 = std::stoi(argv[14]);
+        BatchStrideB1 = std::stoi(argv[15]);
+        BatchStrideC = std::stoi(argv[16]);
     }
     else
     {
@@ -183,29 +195,55 @@ int main(int argc, char* argv[])
         exit(0);
     }

+    const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
+    const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
+    const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
+    const int DefaultStrideC  = ck::is_same_v<CLayout, Row> ? O : M;
+
+    StrideA  = (StrideA < 0) ? DefaultStrideA : StrideA;
+    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
+    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
+    StrideC  = (StrideC < 0) ? DefaultStrideC : StrideC;
+
+    // A batch spans rows * leading-dimension elements for a row-major matrix,
+    // and cols * leading-dimension elements for a column-major one.
+    const int DefaultBatchStrideA  = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
+    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
+    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
+    const int DefaultBatchStrideC  = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
+
+    BatchStrideA  = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
+    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
+    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
+    BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
+                                       std::size_t batch_stride,
                                        auto layout) {
-        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+        if(std::is_same<decltype(layout), Row>::value)
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({row * stride, stride, 1}));
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
         }
         else
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({col * stride, 1, stride}));
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
         }
     };

     // C_m_o = A_m_k * B0_k_n * B1_n_o
-    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
-    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB0, B0Layout{}));
-    Tensor<B1DataType> b1_n_o(f_host_tensor_descriptor(BatchCount, N, O, StrideB1, B1Layout{}));
-    Tensor<CDataType> c_m_o_host_result(f_host_tensor_descriptor(BatchCount, M, O, StrideC, CLayout{}));
-    Tensor<CDataType> c_m_o_device_result(f_host_tensor_descriptor(BatchCount, M, O, StrideC, CLayout{}));
+    Tensor<ADataType> a_m_k(
+        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
+    Tensor<B0DataType> b0_k_n(
+        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
+    Tensor<B1DataType> b1_n_o(
+        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
+    Tensor<CDataType> c_m_o_host_result(
+        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
+    Tensor<CDataType> c_m_o_device_result(
+        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
@@ -241,7 +279,8 @@ int main(int argc, char* argv[])
     b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());

     auto a_element_op = AElementOp{};
-    auto b_element_op = BElementOp{};
+    auto b0_element_op = B0ElementOp{};
+    auto b1_element_op = B1ElementOp{};
     auto c_element_op = CElementOp{};

     // do GEMM
@@ -255,14 +294,19 @@ int main(int argc, char* argv[])
         N,
         K,
         O,
+        BatchCount,
         StrideA,
         StrideB0,
         StrideB1,
         StrideC,
+        BatchStrideA,
+        BatchStrideB0,
+        BatchStrideB1,
+        BatchStrideC,
         a_element_op,
-        b_element_op,
-        c_element_op,
-        BatchCount);
+        b0_element_op,
+        b1_element_op,
+        c_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -290,19 +334,19 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
         // Output of Gemm0 is input A of Gemm1
-        Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, Row{}));
+        Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));

         auto ref_gemm0 = ReferenceGemm0Instance{};
         auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
         auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-            a_m_k, b0_k_n, a1_m_n, a_element_op, b_element_op, c_element_op);
+            a_m_k, b0_k_n, a1_m_n, a_element_op, b0_element_op, PassThrough{});

         ref_gemm0_invoker.Run(ref_gemm0_argument);

         auto ref_gemm1 = ReferenceGemm1Instance{};
         auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
         auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-            a1_m_n, b1_n_o, c_m_o_host_result, a_element_op, b_element_op, c_element_op);
+            a1_m_n, b1_n_o, c_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

         ref_gemm1_invoker.Run(ref_gemm1_argument);
...
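The example's command line changes shape here: with argc == 17, BatchCount moves in front of the strides (argv[8]) and four per-batch strides are appended (argv[13..16]), where a negative value selects a layout-derived default. A minimal standalone sketch of that default computation for the all-row-major case (helper name and the M value below are illustrative, not part of the commit):

#include <cassert>

// Hypothetical helper mirroring the example's fallback logic for row-major
// A[M,K], B0[K,N], B1[N,O], C[M,O]: the leading dimension defaults to the
// column count, and a batch spans rows * leading-dimension elements.
struct BatchedGemmGemmStrides
{
    int StrideA, StrideB0, StrideB1, StrideC;
    int BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC;
};

inline BatchedGemmGemmStrides make_default_row_major_strides(int M, int N, int K, int O)
{
    BatchedGemmGemmStrides s{};
    s.StrideA  = K; // A is M x K
    s.StrideB0 = N; // B0 is K x N
    s.StrideB1 = O; // B1 is N x O
    s.StrideC  = O; // C is M x O
    s.BatchStrideA  = M * s.StrideA;
    s.BatchStrideB0 = K * s.StrideB0;
    s.BatchStrideB1 = N * s.StrideB1;
    s.BatchStrideC  = M * s.StrideC;
    return s;
}

int main()
{
    // The example's default problem uses N = 128, K = 32, O = 128; M = 256 is assumed here.
    const auto s = make_default_row_major_strides(256, 128, 32, 128);
    assert(s.BatchStrideB0 == 32 * 128); // one K x N batch of B0
    return 0;
}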
new file: ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <memory> // for std::unique_ptr used below
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemmGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b0,
                        const void* p_b1,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t O,
                        ck::index_t Batch,
                        ck::index_t StrideA,
                        ck::index_t StrideB0,
                        ck::index_t StrideB1,
                        ck::index_t StrideC,
                        ck::index_t BatchStrideA,
                        ck::index_t BatchStrideB0,
                        ck::index_t BatchStrideB1,
                        ck::index_t BatchStrideC,
                        AElementwiseOperation a_element_op,
                        B0ElementwiseOperation b0_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
                                                                       B0Layout,
                                                                       B1Layout,
                                                                       CLayout,
                                                                       ADataType,
                                                                       B0DataType,
                                                                       B1DataType,
                                                                       CDataType,
                                                                       AElementwiseOperation,
                                                                       B0ElementwiseOperation,
                                                                       B1ElementwiseOperation,
                                                                       CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
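With this pure-virtual interface, client code can be written against DeviceBatchedGemmGemm alone and select a concrete instance at runtime. A minimal sketch, assuming an all-row-major fp16 instance and pre-allocated device buffers; the problem sizes and the StreamConfig-based Run call follow CK's usual BaseInvoker pattern and are illustrative rather than taken from this commit:

// Assumes the relevant CK headers are included (device_batched_gemm_gemm.hpp as
// above, plus the element-wise operation and data-type headers).
using F16         = ck::half_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceOpPtr = ck::tensor_operation::device::
    DeviceBatchedGemmGemmPtr<Row, Row, Row, Row, F16, F16, F16, F16,
                             PassThrough, PassThrough, PassThrough, PassThrough>;

void run(DeviceOpPtr& op, const void* a, const void* b0, const void* b1, void* c)
{
    // Illustrative problem: 4 batches of C[256,128] = A[256,32] * B0[32,128] * B1[128,128].
    auto arg = op->MakeArgumentPointer(a, b0, b1, c,
                                       /*M=*/256, /*N=*/128, /*K=*/32, /*O=*/128,
                                       /*Batch=*/4,
                                       /*StrideA/B0/B1/C=*/32, 128, 128, 128,
                                       /*BatchStrideA/B0=*/256 * 32, 32 * 128,
                                       /*BatchStrideB1/C=*/128 * 128, 256 * 128,
                                       PassThrough{}, PassThrough{},
                                       PassThrough{}, PassThrough{});

    if(!op->IsSupportedArgument(arg.get()))
        return; // instance cannot run this problem

    auto invoker = op->MakeInvokerPointer();
    invoker->Run(arg.get(), StreamConfig{nullptr, false});
}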
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -25,6 +25,7 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
+          typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
@@ -43,6 +44,7 @@ __global__ void
             FloatC* __restrict__ p_c_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
+            const B1ElementwiseOperation b1_element_op,
             const CElementwiseOperation c_element_op,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
@@ -68,14 +70,6 @@ __global__ void
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

-    // if(threadIdx.x == 0)
-    //     printf("bid = %zd, offset a b c d = %zd, %zd, %zd, %zd\n",
-    //            hipBlockIdx_x,
-    //            a_batch_offset,
-    //            b_batch_offset,
-    //            b1_batch_offset,
-    //            c_batch_offset);

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_b1_grid + b1_batch_offset,
@@ -83,6 +77,7 @@ __global__ void
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
+                                                  b1_element_op,
                                                   c_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
@@ -92,20 +87,21 @@
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
+    ignore = p_b1_grid;
     ignore = p_c_grid;
-    ignore = p_shared;
     ignore = a_element_op;
     ignore = b_element_op;
+    ignore = b1_element_op;
     ignore = c_element_op;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
+    ignore = b1_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
-// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Because non c-shuffle
-// version currently has compiler issues with register spill which further causes validation
-// failures.
 // Computes C = A * B0 * B1
 //              ^^^^^^ (Acc0)
 //              ^^^^^^^^^^^ (Acc1)
@@ -114,12 +110,14 @@ template <typename ALayout,
           typename B1Layout,
           typename CLayout,
           typename ADataType,
-          typename BDataType, // NOTE: don't distinguish B0/B1 type just yet
+          typename BDataType,
+          typename B1DataType,
           typename CDataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename AElementwiseOperation,
-          typename BElementwiseOperation, // NOTE: don't distinguish B0/B1 type just yet
+          typename BElementwiseOperation,
+          typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
@@ -163,9 +161,20 @@ template <typename ALayout,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit from DeviceGemmGemm subtype
+struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout,
+                                                                         BLayout,
+                                                                         B1Layout,
+                                                                         CLayout,
+                                                                         ADataType,
+                                                                         BDataType,
+                                                                         B1DataType,
+                                                                         CDataType,
+                                                                         AElementwiseOperation,
+                                                                         BElementwiseOperation,
+                                                                         B1ElementwiseOperation,
+                                                                         CElementwiseOperation>
 {
-    using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -526,6 +535,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         CDataType,
         AElementwiseOperation,
         BElementwiseOperation,
+        B1ElementwiseOperation,
         CElementwiseOperation,
         InMemoryDataOperationEnum::Set,
         AGridDesc_AK0_M_AK1,
@@ -582,20 +592,25 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
     {
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
-                 const BDataType* p_b1_grid,
+                 const B1DataType* p_b1_grid,
                  CDataType* p_c_grid,
                  index_t MRaw,
                  index_t NRaw,
                  index_t KRaw,
                  index_t Gemm1NRaw, // = ORaw
+                 index_t Batch,
                  index_t StrideA,
                  index_t StrideB,
                  index_t StrideB1,
                  index_t StrideC,
+                 index_t BatchStrideA,
+                 index_t BatchStrideB,
+                 index_t BatchStrideB1,
+                 index_t BatchStrideC,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
-                 CElementwiseOperation c_element_op,
-                 index_t Batch)
+                 B1ElementwiseOperation b1_element_op,
+                 CElementwiseOperation c_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},
@@ -609,13 +624,10 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
+             b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              batch_count_(Batch),
-             compute_base_ptr_of_batch_{
-                 type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
-                 type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
-                 type_convert<index_t>(b1_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
-                 type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())}
+             compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,
@@ -632,7 +644,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         // private:
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        const BDataType* p_b1_grid_;
+        const B1DataType* p_b1_grid_;
         CDataType* p_c_grid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
@@ -643,6 +655,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
+        B1ElementwiseOperation b1_element_op_;
         CElementwiseOperation c_element_op_;
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
@@ -680,6 +693,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
             CDataType,
             AElementwiseOperation,
             BElementwiseOperation,
+            B1ElementwiseOperation,
             CElementwiseOperation,
             DeviceOp::AGridDesc_AK0_M_AK1,
             DeviceOp::BGridDesc_BK0_N_BK1,
@@ -700,6 +714,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                 arg.p_c_grid_,
                 arg.a_element_op_,
                 arg.b_element_op_,
+                arg.b1_element_op_,
                 arg.c_element_op_,
                 arg.a_grid_desc_ak0_m_ak1_,
                 arg.b_grid_desc_bk0_n_bk1_,
@@ -760,20 +775,25 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
-                             const BDataType* p_b1,
+                             const B1DataType* p_b1,
                              CDataType* p_c,
                              index_t MRaw,
                              index_t NRaw,
                              index_t KRaw,
                              index_t Gemm1NRaw,
+                             index_t Batch,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideB1,
                              index_t StrideC,
+                             index_t BatchStrideA,
+                             index_t BatchStrideB,
+                             index_t BatchStrideB1,
+                             index_t BatchStrideC,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             index_t Batch)
+                             B1ElementwiseOperation b1_element_op,
+                             CElementwiseOperation c_element_op)
     {
         return Argument{p_a,
                         p_b,
@@ -783,14 +803,19 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                         NRaw,
                         KRaw,
                         Gemm1NRaw,
+                        Batch,
                         StrideA,
                         StrideB,
                         StrideB1,
                         StrideC,
+                        BatchStrideA,
+                        BatchStrideB,
+                        BatchStrideB1,
+                        BatchStrideC,
                         a_element_op,
-                        b_element_op,
-                        c_element_op,
-                        Batch};
+                        b_element_op,
+                        b1_element_op,
+                        c_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -804,35 +829,45 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                         index_t NRaw,
                         index_t KRaw,
                         index_t Gemm1NRaw,
+                        index_t Batch,
                         index_t StrideA,
                         index_t StrideB,
                         index_t StrideB1,
                         index_t StrideC,
+                        index_t BatchStrideA,
+                        index_t BatchStrideB,
+                        index_t BatchStrideB1,
+                        index_t BatchStrideC,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        index_t Batch) /* override */
+                        B1ElementwiseOperation b1_element_op,
+                        CElementwiseOperation c_element_op) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
-                                          static_cast<const BDataType*>(p_b1),
+                                          static_cast<const B1DataType*>(p_b1),
                                           static_cast<CDataType*>(p_c),
                                           MRaw,
                                           NRaw,
                                           KRaw,
                                           Gemm1NRaw,
+                                          Batch,
                                           StrideA,
                                           StrideB,
                                           StrideB1,
                                           StrideC,
+                                          BatchStrideA,
+                                          BatchStrideB,
+                                          BatchStrideB1,
+                                          BatchStrideC,
                                           a_element_op,
-                                          b_element_op,
-                                          c_element_op,
-                                          Batch);
+                                          b_element_op,
+                                          b1_element_op,
+                                          c_element_op);
     }

     // polymorphic
-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() /* override */
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>(Invoker{});
     }
@@ -843,7 +878,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGemmGemm_Xdl_CShuffle"
+        str << "DeviceBatchedGemmGemm_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
            << MPerBlock << ", "
@@ -851,7 +886,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
-           << NPerBlock << ", "
+           << MPerBlock << ", "
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << B1K1 << ">";
...
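With the Argument change above, ComputeBasePtrOfStridedBatch is now fed the user-provided batch strides instead of the descriptors' element-space sizes, so batches may be padded or interleaved in memory. A sketch of the helper's likely shape, inferred from its constructor call and the GetABasePtr/GetCBasePtr uses in the kernel (member names are guesses, not taken from the commit):

struct ComputeBasePtrOfStridedBatch
{
    ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                 index_t BatchStrideB,
                                 index_t BatchStrideB1,
                                 index_t BatchStrideC)
        : BatchStrideA_(BatchStrideA),
          BatchStrideB_(BatchStrideB),
          BatchStrideB1_(BatchStrideB1),
          BatchStrideC_(BatchStrideC)
    {
    }

    // Batch g starts g * BatchStrideX elements into each tensor.
    __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideA_);
    }
    __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB_);
    }
    __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB1_);
    }
    __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideC_);
    }

    private:
    index_t BatchStrideA_, BatchStrideB_, BatchStrideB1_, BatchStrideC_;
};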
@@ -23,6 +23,7 @@ template <typename FloatAB,
           typename FloatC,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
+          typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
@@ -316,6 +317,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                void* __restrict__ p_shared,
                const AElementwiseOperation& a_element_op,
                const BElementwiseOperation& b_element_op,
+               const B1ElementwiseOperation& b1_element_op,
                const CElementwiseOperation& c_element_op,
                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
...
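For verification the example chains two host reference GEMMs, feeding Gemm0's output in as Gemm1's A input with PassThrough element-ops at the seam. Per batch, that check amounts to the following naive computation (a sketch for intuition only, not the CK reference path; it keeps Acc0 in fp32 rather than casting back to ADataType between the two GEMMs as the example does):

#include <vector>

// Naive host reference for one batch of C[M,O] = (A[M,K] * B0[K,N]) * B1[N,O],
// all matrices row-major, element-ops elided (PassThrough) for brevity.
template <typename T, typename Acc = float>
void reference_gemm_gemm(const std::vector<T>& a,  // M x K
                         const std::vector<T>& b0, // K x N
                         const std::vector<T>& b1, // N x O
                         std::vector<T>& c,        // M x O
                         int M, int N, int K, int O)
{
    // Acc0: output of Gemm0, input A of Gemm1
    std::vector<Acc> acc0(static_cast<std::size_t>(M) * N, Acc{0});
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                acc0[m * N + n] +=
                    static_cast<Acc>(a[m * K + k]) * static_cast<Acc>(b0[k * N + n]);

    // Acc1: output of Gemm1, written to C
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
        {
            Acc acc1{0};
            for(int n = 0; n < N; ++n)
                acc1 += acc0[m * N + n] * static_cast<Acc>(b1[n * O + o]);
            c[m * O + o] = static_cast<T>(acc1);
        }
}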