Unverified Commit 3f299c33 authored by Adam Osewski, committed by GitHub

Merge branch 'develop' into aosewski/ggemm_dl_instances

parents 507d793a 091570f5
...@@ -76,4 +76,8 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -76,4 +76,8 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perchannel_quantization_example.inc" #include "run_conv2d_fwd_perchannel_quantization_example.inc"
int main() { run_conv2d_fwd_perchannel_quantization_example(); } int main()
{
const auto out_element_op = OutElementOp{ActivationOp{}};
run_conv2d_fwd_perchannel_quantization_example(out_element_op);
}
...@@ -71,4 +71,9 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -71,4 +71,9 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perlayer_quantization_example.inc" #include "run_conv2d_fwd_perlayer_quantization_example.inc"
int main() { run_conv2d_fwd_perlayer_quantization_example(); } int main()
{
float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_perlayer_quantization_example(out_element_op);
}
...@@ -80,6 +80,10 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -80,6 +80,10 @@ using DeviceGroupedConvNDFwdInstance =
S<1, 64, 1, 4>, S<1, 64, 1, 4>,
8>; 8>;
#include "run_conv2d_fwd_bias_relu_perchannel_quantization_example.inc" #include "run_conv2d_fwd_bias_perchannel_quantization_example.inc"
int main() { run_conv2d_fwd_bias_relu_perchannel_quantization_example(); }; int main()
{
const auto out_element_op = OutElementOp{ActivationOp{}};
run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op);
};
...@@ -78,6 +78,11 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -78,6 +78,11 @@ using DeviceGroupedConvNDFwdInstance =
S<1, 64, 1, 4>, S<1, 64, 1, 4>,
8>; 8>;
#include "run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc" #include "run_conv2d_fwd_bias_perlayer_quantization_example.inc"
int main() { run_conv2d_fwd_bias_relu_perlayer_quantization_example(); } int main()
{
float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op);
}
...@@ -80,4 +80,8 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -80,4 +80,8 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perchannel_quantization_example.inc" #include "run_conv2d_fwd_perchannel_quantization_example.inc"
int main() { run_conv2d_fwd_perchannel_quantization_example(); } int main()
{
const auto out_element_op = OutElementOp{ActivationOp{}};
run_conv2d_fwd_perchannel_quantization_example(out_element_op);
}
...@@ -75,4 +75,9 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -75,4 +75,9 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perlayer_quantization_example.inc" #include "run_conv2d_fwd_perlayer_quantization_example.inc"
int main() { run_conv2d_fwd_perlayer_quantization_example(); } int main()
{
float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_perlayer_quantization_example(out_element_op);
}
...@@ -167,7 +167,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -167,7 +167,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_bias_relu_perchannel_quantization_example() int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -189,7 +189,6 @@ int run_conv2d_fwd_bias_relu_perchannel_quantization_example() ...@@ -189,7 +189,6 @@ int run_conv2d_fwd_bias_relu_perchannel_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_bias_relu_perlayer_quantization_example() int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -177,7 +177,6 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example() ...@@ -177,7 +177,6 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_perchannel_quantization_example() int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -179,7 +179,6 @@ int run_conv2d_fwd_perchannel_quantization_example() ...@@ -179,7 +179,6 @@ int run_conv2d_fwd_perchannel_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_perlayer_quantization_example() int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = false; bool time_kernel = false;
...@@ -161,7 +161,6 @@ int run_conv2d_fwd_perlayer_quantization_example() ...@@ -161,7 +161,6 @@ int run_conv2d_fwd_perlayer_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#elif defined(__gfx1030__) // for GPU code #elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
// FMA instruction // FMA instruction
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -73,157 +73,18 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -73,157 +73,18 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto K1Number = Number<K1>{};
static auto
MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
{
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
static auto
MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
{
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, ALayout,
AGridDesc_K0_M_K1, BLayout,
BGridDesc_K0_N_K1, CLayout,
CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
GemmSpec,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
...@@ -253,236 +114,64 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -253,236 +114,64 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
// GridwiseGemm using Argument = typename GridwiseGemm::Argument;
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t k_batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
k_batch_{k_batch}
{
int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_);
a_grid_desc_kbatch_k0_m_k1_ =
DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1(
M, K, StrideA, k_batch_, KPad);
b_grid_desc_kbatch_k0_n_k1_ =
DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1(
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t k_batch_;
};
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceGemmXdlSplitKCShuffle::Argument;
void Print(const Argument& arg) void Print(const Argument& karg) { karg.Print(); }
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
Print(arg); Print(karg);
} }
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); const auto kbatch = karg.k_batch;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, if(!GridwiseGemm::CheckValidity(karg))
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
"setting");
} }
const index_t grid_size = index_t gdx, gdy, gdz;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg);
const auto K0 = karg.K0;
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0; float ave_time = 0;
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
hipGetErrorString(hipMemset( if(kbatch > 1)
arg.p_c_grid_, hipGetErrorString(
0, hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType))); ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
if(kbatch == 1) if(kbatch == 1)
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype true,
CDataType, InMemoryDataOperationEnum::Set>;
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
true>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel =
GridwiseGemmAtomicAdd, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype true,
CDataType, InMemoryDataOperationEnum::AtomicAdd>;
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
true>;
Run(kernel); Run(kernel);
} }
...@@ -491,37 +180,19 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -491,37 +180,19 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
{ {
if(kbatch == 1) if(kbatch == 1)
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype false,
CDataType, InMemoryDataOperationEnum::Set>;
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
false>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel =
GridwiseGemmAtomicAdd, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype false,
CDataType, InMemoryDataOperationEnum::AtomicAdd>;
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
false>;
Run(kernel); Run(kernel);
} }
...@@ -544,12 +215,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -544,12 +215,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& karg)
{ {
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, return GridwiseGemm::CheckValidity(karg);
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
} }
// polymorphic // polymorphic
...@@ -567,9 +235,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -567,9 +235,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CElementwiseOperation c_element_op, CElementwiseOperation,
index_t KBatch) index_t KBatch)
{ {
return Argument{p_a, return Argument{p_a,
...@@ -581,11 +249,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -581,11 +249,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
1, GridwiseGemm::CalculateMPadded(M),
1, GridwiseGemm::CalculateNPadded(N),
a_element_op, GridwiseGemm::CalculateKPadded(K),
b_element_op, GridwiseGemm::CalculateK0(K, KBatch),
c_element_op,
KBatch}; KBatch};
} }
...@@ -601,9 +268,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -601,9 +268,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CElementwiseOperation c_element_op, CElementwiseOperation,
ck::index_t KBatch = 1) override ck::index_t KBatch = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
...@@ -615,11 +282,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -615,11 +282,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
1, GridwiseGemm::CalculateMPadded(M),
1, GridwiseGemm::CalculateNPadded(N),
a_element_op, GridwiseGemm::CalculateKPadded(K),
b_element_op, GridwiseGemm::CalculateK0(K, KBatch),
c_element_op,
KBatch); KBatch);
} }
...@@ -630,31 +296,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -630,31 +296,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
} }
// polymorphic // polymorphic
std::string GetTypeString() const override std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdlSplitKCShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">";
// clang-format on
return str.str();
}
}; };
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -7,10 +7,30 @@ namespace ck { ...@@ -7,10 +7,30 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
// Y = Sy * Qy
// W = Sw * Qw
// X = Sx * Qx
// B = Sb * Qb = Sw * Sx * Qb
// Where X, W, Y are float32, Qx, Qw, Qy are int8
// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range
// Qb is int32, scale of B is Sw * Sx for convenience
// Y = W @ X, where @ is convolution or matrix multiplication
// Sy * Qy = Sw * Qw @ Sx * Qx
// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Activation_Mul_Clamp struct Activation_Mul_Clamp
{ {
// Convolution + Activation (piecewise linear function)
// If an activation is a piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
// requantScale_ = Sw * Sx / Sz
Activation_Mul_Clamp(float requantScale, Activation activationOp) Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
...@@ -45,8 +65,39 @@ struct Activation_Mul_Clamp ...@@ -45,8 +65,39 @@ struct Activation_Mul_Clamp
Activation activationOp_; Activation activationOp_;
}; };
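A minimal host-side sketch of the per-layer requantization math described in the comments above; the function name and the use of ReLU are illustrative, not part of CK, and CK converts with ck::type_convert rather than static_cast:

#include <cstdint>

// Qz = clamp(((Sw * Sx) / Sz) * Activation(Qacc), -128, 127)
inline int8_t requantize_relu(int32_t acc, float requant_scale /* = (Sw * Sx) / Sz */)
{
    float y = static_cast<float>(acc);
    y = y > 0.f ? y : 0.f;            // ReLU, a piecewise linear activation
    y = requant_scale * y;            // per-layer requantization scale
    if(y < -128.f) y = -128.f;        // clamp to the int8 range
    if(y > 127.f) y = 127.f;
    return static_cast<int8_t>(y);
}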
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not a piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Mul_Activation_Mul_Clamp
{
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float y_fp32 = ck::type_convert<float>(x);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
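A small usage sketch of the new Mul_Activation_Mul_Clamp functor together with the TanH operator added later in this change; the scale values are invented and the relevant CK headers are assumed to be included:

#include <cstdint>

void mul_activation_mul_clamp_example()
{
    using ck::tensor_operation::element_wise::Mul_Activation_Mul_Clamp;
    using ck::tensor_operation::element_wise::TanH;

    // scale_z_inv = 1 / Sz, scaleAcc = Sw * Sx (illustrative values)
    Mul_Activation_Mul_Clamp<TanH> out_op{4.0f, 0.02f, TanH{}};

    int8_t qz = 0;
    out_op(qz, int32_t{100}); // qz = clamp(4.0f * tanh(0.02f * 100), -128.f, 127.f) as int8
}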
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as // Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc // relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Activation_Mul2_Clamp struct Activation_Mul2_Clamp
{ {
...@@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp ...@@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp
}; };
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Add_Activation_Mul_Clamp struct Add_Activation_Mul_Clamp
{ {
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
// Y = W @ X + B
// Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
// Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
// For activation, Z = Activation(Y)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
...@@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp ...@@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp
}; };
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc // For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not a piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp struct Add_Mul_Activation_Mul_Clamp
{ {
Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp) // Convolution + Activation (non piecewise linear function)
: requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp) // Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{ {
} }
...@@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp ...@@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{ {
float y_fp32 = ck::type_convert<float>(x + bias); float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = requantScale1_ * y_fp32; y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUTION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32); activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f); y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is non piecewise linear function,
// such as TanH, Sigmoid ...etc
// If an activation is not a piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Add_Mul2_Activation_Mul_Clamp
{
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
float requantScale1_; __host__ __device__ constexpr void
float requantScale2_; operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
// CAUTION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
Activation activationOp_; Activation activationOp_;
}; };
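A sketch contrasting the per-channel variant above with the per-layer one: only 1 / Sz is fixed at construction time, while the accumulator scale Sw[k] * Sx is supplied per output channel when the functor is applied; the values below are invented and the CK headers are assumed to be included:

#include <cstdint>

void add_mul2_activation_mul_clamp_example()
{
    using ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp;
    using ck::tensor_operation::element_wise::TanH;

    Add_Mul2_Activation_Mul_Clamp<TanH> out_op{/* scale_z_inv = */ 4.0f, TanH{}};

    const float per_channel_scale[2] = {0.02f, 0.05f}; // Sw[k] * Sx for each output channel

    int8_t qz = 0;
    out_op(qz, /* acc = */ int32_t{90}, /* bias = */ int32_t{10}, per_channel_scale[0]);
}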
......
...@@ -320,6 +320,19 @@ struct Sigmoid ...@@ -320,6 +320,19 @@ struct Sigmoid
int32_t divider_ = 1; int32_t divider_ = 1;
}; };
struct TanH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
};
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
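A tiny host-side usage example of the TanH functor added above (the input value is arbitrary):

void tanh_example()
{
    float y = 0.f;
    ck::tensor_operation::element_wise::TanH{}(y, 1.0f); // y ≈ 0.7616
}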
...@@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_desc_k0perblock_nperblock_k1 = constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
...@@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_space_size_aligned = math::integer_least_multiple( constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned * sizeof(ADataType) + constexpr auto c_block_space_size_aligned = math::integer_least_multiple(
b_block_space_size_aligned * sizeof(BDataType)); cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(),
max_lds_align);
return math::max((a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType)),
c_block_space_size_aligned * sizeof(CShuffleDataType));
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......
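The LDS size change above reserves the larger of the two footprints because the A/B tiles and the later C-shuffle staging buffer occupy the same shared memory at different times; a minimal sketch of that rule, with an illustrative function name:

#include <algorithm>
#include <cstddef>

// A and B tiles live in LDS together during the GEMM main loop; the C-shuffle
// buffer reuses the same LDS afterwards, so only the larger footprint is needed.
inline std::size_t lds_bytes(std::size_t a_bytes, std::size_t b_bytes, std::size_t c_shuffle_bytes)
{
    return std::max(a_bytes + b_bytes, c_shuffle_bytes);
}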
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma is fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if defined(__gfx90a__)
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
using ABDataTypeAdjusted = ABDataType;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
...@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABDataType, ABDataType,
ABDataType, ABDataTypeAdjusted,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
ABDataType, ABDataType,
ABDataType, ABDataTypeAdjusted,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1, BK1), math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataType, ABDataTypeAdjusted,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ABDataTypeAdjusted*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned, static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -166,15 +166,12 @@ __global__ void ...@@ -166,15 +166,12 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_block_size = __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared,
a_b_k0_m_k1_grid_desc, a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc, b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -183,16 +180,16 @@ __global__ void ...@@ -183,16 +180,16 @@ __global__ void
c_element_op, c_element_op,
c_block_cluster_adaptor); c_block_cluster_adaptor);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc; ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc; ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = c_block_cluster_adaptor; ignore = c_block_cluster_adaptor;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma is fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
#endif
// M0/M1/M1Padding // M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{}; static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{}; static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};
...@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
...@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc), decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc), decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk); math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
...@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize());
// gridwise GEMM pipeline // gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
...@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
void* p_shared = static_cast<void*>(p_shared_block);
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared), static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......