Commit 68f946f5 authored by Jing Zhang

replace gridwise_v2r3 with multiD
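
Switch DeviceGroupedGemmXdl from GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 to GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle: GemmShape becomes GemmDesc in the DeviceGroupedGemm interface, the A/B descriptors get separate AK1/BK1 and full per-dimension padding specializations, the v2r3 CThreadTransfer tuning parameters are replaced by CShuffle/CDEBlockTransfer parameters, and the per-group kernel arguments move to an explicit workspace registered through SetWorkSpacePointer.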

parent 12235112
@@ -48,11 +48,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer|BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer|BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
@@ -81,11 +81,11 @@ int main(int argc, char* argv[])
int group_count = rand() % 16 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
gemm_shapes.reserve(group_count);
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
@@ -93,7 +93,7 @@ int main(int argc, char* argv[])
int N = 128 + 128 * i;
int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N});
gemm_descs.push_back({M, N, K, K, K, N});
}
auto f_host_tensor_descriptor =
@@ -131,22 +131,22 @@ int main(int argc, char* argv[])
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{})));
gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{})));
gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
@@ -168,7 +168,7 @@ int main(int argc, char* argv[])
}
}
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
// do GEMM
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
gemm.MakeArgument(p_a, p_b, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
@@ -219,7 +219,7 @@ int main(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
......
@@ -11,12 +11,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
struct GemmShape
{
ck::index_t M, N, K;
ck::index_t StrideA, StrideB, StrideC;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
......
#pragma once
#include <iostream>
#include <memory>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct GemmDesc
{
ck::index_t M_, N_, K_;
ck::index_t stride_A_, stride_B_, stride_C_;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmDesc>& gemm_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGroupedGemmPtr = std::unique_ptr<
DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
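
For reference, a minimal sketch of how a caller is expected to drive this interface, modeled on the grouped-GEMM example earlier in this commit; DeviceGemmInstance, DeviceMem, PassThrough, StreamConfig, and the per-group M/N/K and buffer arrays are assumptions carried over from that example, not part of this header:

// Sketch only: run a grouped GEMM through the DeviceGroupedGemm interface.
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
for(int i = 0; i < group_count; i++)
{
    // {M, N, K, StrideA, StrideB, StrideC}, matching the example's
    // row-major A/C and column-major B strides.
    gemm_descs.push_back({M[i], N[i], K[i], K[i], K[i], N[i]});
    p_a.push_back(a_dev_buf[i]);
    p_b.push_back(b_dev_buf[i]);
    p_c.push_back(c_dev_buf[i]);
}
auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
    p_a, p_b, p_c, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
// The per-group kernel arguments live in a caller-provided device buffer:
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});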
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
@@ -10,9 +11,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
@@ -21,8 +22,6 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -32,11 +31,11 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -56,17 +55,21 @@ __global__ void
}
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
gemm_desc_ptr[group_id].a_ptr_,
gemm_desc_ptr[group_id].b_ptr_,
ck::Tuple<>{},
gemm_desc_ptr[group_id].c_ptr_,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
0>{},
gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].block_2_ctile_map_);
#else
ignore = gemm_descs_const;
ignore = group_count;
@@ -76,22 +79,24 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
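
The group lookup elided from the kernel body above maps the flat workgroup id to a group via the per-group BlockStart_/BlockEnd_ ranges. A sketch of that lookup, assuming a simple linear scan (the actual elided code may search the ranges differently):

// Assumed lookup: each per-group argument carries a half-open block range
// [BlockStart_, BlockEnd_); find the group that owns this workgroup.
const index_t block_id   = get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
    cast_pointer_to_generic_address_space(gemm_descs_const));
index_t group_id = 0;
for(index_t i = 0; i < group_count; ++i)
{
    if(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
    {
        group_id = i;
        break;
    }
}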
template <typename ADataType,
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
@@ -110,10 +115,11 @@ template <typename ADataType,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1,
ck::index_t MaxGroupCount = 10>
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemmXdl
: public DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
@@ -121,144 +127,297 @@ struct DeviceGroupedGemmXdl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
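
To make the padding arithmetic concrete with illustrative values: for MRaw = 1000, KRaw = 100, MPerBlock = 256, KPerBlock = 32, and AK1 = 8, this computes M = ceil(1000/256)*256 = 1024 and K = ceil(100/32)*32 = 128, so MPad = 24 and KPad = 28; in the MKPadding branch the padded K is then unmerged into AK0 = 128/8 = 16 slices of length AK1 = 8.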
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// pad K, but not N
assert(K % BK1 == 0);
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
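
Note that the output descriptor never pads K: since E has no K dimension, MPadding/MKPadding pad only M here, NPadding/NKPadding pad only N, and a pure KPadding specialization falls through to the unpadded descriptor.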
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
ck::Tuple<>,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_M_N,
NumPrefetch, // NumGemmKPrefetchStage
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -277,30 +436,28 @@ struct DeviceGroupedGemmXdl
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
NumPrefetch>;
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
struct GroupedGemmBlock2CTileMap
struct GroupedGemmBlock2ETileMap
{
using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
static_assert(
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)),
typename GridwiseGemm::DefaultBlock2CTileMap>::value,
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})),
typename GridwiseGemm::DefaultBlock2ETileMap>::value,
"Wrong! Should be the same type name");
GroupedGemmBlock2CTileMap()
GroupedGemmBlock2ETileMap()
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
BlockStart_ = -1;
}
GroupedGemmBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01,
ck::index_t BlockStart)
GroupedGemmBlock2ETileMap(const EGridDesc_M_N& c_grid_desc_m_n, ck::index_t BlockStart)
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n);
BlockStart_ = BlockStart;
}
@@ -318,29 +475,29 @@ struct DeviceGroupedGemmXdl
return block_2_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ bool CheckValidity(const EGridDesc_M_N& c_grid_desc_m_n) const
{
return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_ctile_map_;
ck::index_t BlockStart_;
};
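
The CalculateBottomIndex override elided just above presumably rebases the global block id into this group's local range before delegating to the wrapped map; a sketch of that wrapper, inferred from the BlockStart_ member rather than taken from the elided code:

// Assumed: shift the global 1-D block id by the group's starting block
// before handing it to the underlying per-group tile map.
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
    return block_2_ctile_map_.CalculateBottomIndex(
        make_multi_index(idx_top[Number<0>{}] - BlockStart_));
}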
struct GemmDescKernelArg
struct GemmBiasTransKernelArg
{
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1_;
BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1_;
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_;
GroupedGemmBlock2ETileMap block_2_ctile_map_;
const ADataType* a_ptr;
const BDataType* b_ptr;
CDataType* c_ptr;
const ADataType* a_ptr_;
const BDataType* b_ptr_;
CDataType* c_ptr_;
ck::index_t BlockStart_, BlockEnd_;
};
@@ -348,97 +505,93 @@ struct DeviceGroupedGemmXdl
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
index_t M01,
index_t N01,
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
: a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
{
grid_size_ = 0;
p_workspace_ = nullptr;
gemm_descs_args_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_b.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_c.size())))
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
}
gemm_desc_kernel_arg_.reserve(group_count_);
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
const index_t M = gemm_shapes[i].M;
const index_t N = gemm_shapes[i].N;
const index_t K = gemm_shapes[i].K;
const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
const index_t StrideA = gemm_shapes[i].StrideA;
const index_t StrideB = gemm_shapes[i].StrideB;
const index_t StrideC = gemm_shapes[i].StrideC;
const index_t StrideA = gemm_descs[i].stride_A_;
const index_t StrideB = gemm_descs[i].stride_B_;
const index_t StrideC = gemm_descs[i].stride_C_;
const auto a_grid_desc_k0_m_k1_ =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
DeviceGroupedGemmXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA);
const auto b_grid_desc_k0_n_k1_ =
DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
const auto c_grid_desc_m_n_ =
DeviceGroupedGemmXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB);
const auto e_grid_desc_m_n_ =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
const index_t grid_size_grp =
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0)
.block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_);
GroupedGemmBlock2ETileMap(e_grid_desc_m_n_, 0)
.block_2_ctile_map_.CalculateGridSize(e_grid_desc_m_n_);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
const auto grouped_gemm_block_2_ctile_map_ =
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
const auto block_2_ctile_map_ =
GroupedGemmBlock2ETileMap(e_grid_desc_m_n_, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
grouped_gemm_block_2_ctile_map_))
e_grid_desc_m_n_,
block_2_ctile_map_))
{
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
auto e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
gemm_desc_kernel_arg_.push_back(
GemmDescKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
grouped_gemm_block_2_ctile_map_,
static_cast<const ADataType*>(p_a[i]),
static_cast<const BDataType*>(p_b[i]),
static_cast<CDataType*>(p_c[i]),
BlockStart,
BlockEnd});
GemmBiasTransKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
e_grid_desc_m_n_,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map_,
static_cast<const ADataType*>(p_As[i]),
static_cast<const BDataType*>(p_Bs[i]),
static_cast<CDataType*>(p_Es[i]),
BlockStart,
BlockEnd});
}
}
}
// private:
index_t M01_;
index_t N01_;
index_t group_count_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_;
};
@@ -464,16 +617,15 @@ struct DeviceGroupedGemmXdl
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
std::cout << ", arg.c_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
......@@ -489,60 +641,41 @@ struct DeviceGroupedGemmXdl
}
hipGetErrorString(
hipMemcpy(arg.p_workspace_,
hipMemcpy(arg.gemm_descs_args_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice));
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
GemmBiasTransKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
if(has_main_k_block_loop)
{
const auto kernel =
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
true>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
const auto kernel =
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
......@@ -576,31 +709,30 @@ struct DeviceGroupedGemmXdl
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape> gemm_shapes,
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
return Argument{p_As, p_Bs, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(
p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
p_As, p_Bs, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
}
// polymorphic
@@ -620,8 +752,9 @@ struct DeviceGroupedGemmXdl
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
@@ -634,7 +767,12 @@ struct DeviceGroupedGemmXdl
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
}
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
}
};
......