Commit 6ef4e211 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into contraction

parents b0a2afb9 9e4429f9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// C[M, N] = activate(A[M, K] * B[K, N] + C0[N]) + C1[M, N]
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
: public DeviceGemmBiasActivationAdd<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGemmXdl_C_Shuffle_Bias_Activation_Add;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
// A[K0, M, K1]
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_k0_m_k1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B[K0, N, K1]
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_k0_n_k1 =
transform_tensor_descriptor(b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C[M, N]
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
// C0[N]: assume a contiguous vector
const auto c0_grid_desc_m_n =
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1));
// C1[M, N]: residual tensor: assume same layout as C
const auto c1_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC1, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC1));
}
}();
return make_tuple(a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
c0_grid_desc_m_n,
c1_grid_desc_m_n);
}
using GridDescs =
decltype(MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(1, 1, 1, 1, 1, 1, 1));
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(GridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(GridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I2])>;
using C0GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I3])>;
using C1GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I4])>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
C0GridDesc_M_N,
C1GridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
CBlockTransferScalarPerVector_NWaveNPerXdl>;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const CDataType* p_c0_grid,
const CDataType* p_c1_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid},
p_c1_grid_{p_c1_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c0_grid_desc_m_n_{},
c1_grid_desc_m_n_{},
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
const auto descs = DeviceOp::MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(
M, N, K, StrideA, StrideB, StrideC, StrideC1);
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
c0_grid_desc_m_n_ = descs[I3];
c1_grid_desc_m_n_ = descs[I4];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_);
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const CDataType* p_c0_grid_;
const CDataType* p_c1_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v3r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
remove_reference_t<
typename GridwiseGemm::
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
remove_reference_t<
typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_xdlops_v3r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
remove_reference_t<
typename GridwiseGemm::
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
remove_reference_t<
typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const CDataType* p_c0,
const CDataType* p_c1,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_c0,
p_c1,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const CDataType*>(p_c0),
static_cast<const CDataType*>(p_c1),
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdl_C_Shuffle_Bias_Activation_Add"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -65,8 +65,15 @@ template <typename ALayout, ...@@ -65,8 +65,15 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemm_Xdl_CShuffle struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
using DeviceOp = DeviceGemm_Xdl_CShuffle; using DeviceOp = DeviceGemm_Xdl_CShuffle;
...@@ -622,8 +629,7 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -622,8 +629,7 @@ struct DeviceGemm_Xdl_CShuffle
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers
// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example,
// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128
//
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
//
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
{
using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideC));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
static auto MakeGridDescriptor_N(index_t NRaw)
{
const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad N
return transform_tensor_descriptor(grid_desc_nraw,
make_tuple(make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad N
return grid_desc_nraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_N = decltype(MakeGridDescriptor_N(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
C0DataType,
ReduceAccDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
C0GridDesc_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
LoopSched>;
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const C0DataType* p_c0_grid_add,
const C0DataType* p_c0_grid_bias,
const C0DataType* p_c0_grid_gamma,
const C0DataType* p_c0_grid_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_grid_bias_{p_c0_grid_bias},
p_c0_grid_add_{p_c0_grid_add},
p_c0_grid_gamma_{p_c0_grid_gamma},
p_c0_grid_beta_{p_c0_grid_beta},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_nblock_nperblock_{},
block_2_ctile_map_{Block2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
c_element_op_{c_element_op}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
c0_grid_desc_nblock_nperblock_ =
GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_bias_;
const C0DataType* p_c0_grid_add_;
const C0DataType* p_c0_grid_gamma_;
const C0DataType* p_c0_grid_beta_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_N c0_grid_desc_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_;
Block2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
C0DataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
Block2CTileMap,
true>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_bias_,
arg.p_c0_grid_add_,
arg.p_c0_grid_gamma_,
arg.p_c0_grid_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
C0DataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
Block2CTileMap,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_bias_,
arg.p_c0_grid_add_,
arg.p_c0_grid_gamma_,
arg.p_c0_grid_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_nblock_nperblock_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const C0DataType* p_c0_bias,
const C0DataType* p_c0_add,
const C0DataType* p_c0_gamma,
const C0DataType* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_c0_bias,
p_c0_add,
p_c0_gamma,
p_c0_beta,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
acc_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0_bias,
const void* p_c0_add,
const void* p_c0_gamma,
const void* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const C0DataType*>(p_c0_bias),
static_cast<const C0DataType*>(p_c0_add),
static_cast<const C0DataType*>(p_c0_gamma),
static_cast<const C0DataType*>(p_c0_beta),
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmLayerNorm_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
...@@ -56,8 +56,15 @@ template <typename ADataType, ...@@ -56,8 +56,15 @@ template <typename ADataType,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceGemmXdlSplitK struct DeviceGemmXdlSplitK : public DeviceGemmSplitK<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
...@@ -58,8 +58,15 @@ template <typename ADataType, ...@@ -58,8 +58,15 @@ template <typename ADataType,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL> index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
struct DeviceGemmXdlSplitKCShuffle struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -420,21 +427,22 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -420,21 +427,22 @@ struct DeviceGemmXdlSplitKCShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType))); sizeof(CDataType)));
launch_and_time_kernel(stream_config, ave_time =
kernel, launch_and_time_kernel(stream_config,
dim3(grid_size), kernel,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_a_grid_, 0,
arg.p_b_grid_, arg.p_a_grid_,
arg.p_c_grid_, arg.p_b_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_, arg.p_c_grid_,
arg.b_grid_desc_kbatch_k0_n_k1_, arg.a_grid_desc_kbatch_k0_m_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_element_op_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.b_element_op_, arg.a_element_op_,
arg.c_element_op_, arg.b_element_op_,
arg.block_2_ctile_map_); arg.c_element_op_,
arg.block_2_ctile_map_);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
......
...@@ -46,13 +46,22 @@ __global__ void ...@@ -46,13 +46,22 @@ __global__ void
const auto gemm_desc_ptr = const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const)); reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
index_t group_id = 0; index_t left = 0;
for(index_t i = 0; i < group_count; i++) index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
left <= right)
{ {
group_id = if(block_id < gemm_desc_ptr[group_id].BlockStart_)
(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_) {
? i right = group_id;
: group_id; }
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
} }
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
...@@ -11,34 +12,31 @@ namespace ck { ...@@ -11,34 +12,31 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename AElementwiseOperation, struct DeviceNormalization : public BaseOperator
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmBiasActivation : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, // inLengths: input tensor extent(s) from high to low dimension
const void* p_b, // inStrides: input tensor stride(s) from high to low dimension
void* p_c, // reduceDims: the dimension(s) the normalization operation is applied
const void* p_c0, // alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
ck::index_t M, // beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
ck::index_t N, // in_dev: typeless const pointer in device memory storing the input tensor
ck::index_t K, // out_dev: typeless pointer in device memory storing the output tensor
ck::index_t StrideA, virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
ck::index_t StrideB, const std::vector<index_t> inStrides,
ck::index_t StrideC, const std::vector<int> reduceDims,
AElementwiseOperation a_element_op, const void* alpha,
BElementwiseOperation b_element_op, const void* beta,
CElementwiseOperation c_element_op, const void* in_dev,
ck::index_t KBatch = 1) = 0; void* out_dev) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual index_t GetRank() const = 0;
virtual index_t GetNumReduceDim() const = 0;
}; };
template <typename AElementwiseOperation, using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasActivationPtr = std::unique_ptr<
DeviceGemmBiasActivation<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/utility/reduction_operator.hpp" #include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" #include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
...@@ -33,8 +34,15 @@ template <typename InDataType, ...@@ -33,8 +34,15 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceSoftmax : public BaseOperator struct DeviceSoftmax : public DeviceNormalization
{ {
static constexpr index_t kRank = Rank;
static constexpr index_t kNumReduceDim = NumReduceDim;
virtual index_t GetRank() const override { return kRank; }
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
// Used for freeloading of some handy functions from DeviceReduceMultiBlock // Used for freeloading of some handy functions from DeviceReduceMultiBlock
...@@ -61,18 +69,33 @@ struct DeviceSoftmax : public BaseOperator ...@@ -61,18 +69,33 @@ struct DeviceSoftmax : public BaseOperator
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType, using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
GridDesc_M_K, GridDesc_M_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
OutDstVectorSize>; OutDstVectorSize,
false>;
using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
true>;
struct Argument : public Reduction::Argument struct Argument : public Reduction::Argument
{ {
...@@ -121,8 +144,19 @@ struct DeviceSoftmax : public BaseOperator ...@@ -121,8 +144,19 @@ struct DeviceSoftmax : public BaseOperator
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto kernel_main = bool sweep_once =
kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>; in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>
: kernel_softmax<GridwiseSoftmaxGeneric,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>;
float avg_time = 0; float avg_time = 0;
...@@ -167,24 +201,34 @@ struct DeviceSoftmax : public BaseOperator ...@@ -167,24 +201,34 @@ struct DeviceSoftmax : public BaseOperator
return true; return true;
}; };
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the softmax normalization operate on
// alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths, std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
AccDataType alpha, const void* alpha,
AccDataType beta, const void* beta,
const void* in_dev, const void* in_dev,
void* out_dev) void* out_dev) override
{ {
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
reduceDims, reduceDims,
alpha, *static_cast<const AccDataType*>(alpha),
beta, *static_cast<const AccDataType*>(beta),
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev)); static_cast<OutDataType*>(out_dev));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }; std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
......
...@@ -19,6 +19,22 @@ enum struct GemmSpecialization ...@@ -19,6 +19,22 @@ enum struct GemmSpecialization
MNKPadding, MNKPadding,
}; };
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
{
switch(s)
{
case GemmSpecialization::Default: return "Default";
case GemmSpecialization::MPadding: return "MPadding";
case GemmSpecialization::NPadding: return "NPadding";
case GemmSpecialization::KPadding: return "KPadding";
case GemmSpecialization::MNPadding: return "MNPadding";
case GemmSpecialization::MKPadding: return "MKPadding";
case GemmSpecialization::NKPadding: return "NKPadding";
case GemmSpecialization::MNKPadding: return "MNKPadding";
default: return "Unrecognized specialization!";
}
}
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -11,8 +11,8 @@ namespace element_wise { ...@@ -11,8 +11,8 @@ namespace element_wise {
struct Add struct Add
{ {
template <typename T> template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
...@@ -28,7 +28,13 @@ struct Add ...@@ -28,7 +28,13 @@ struct Add
y = x0 + x1; y = x0 + x1;
}; };
// Question: should half_t be supported ? template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) + x1;
};
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
...@@ -36,7 +42,6 @@ struct Add ...@@ -36,7 +42,6 @@ struct Add
y = x0 + x1; y = x0 + x1;
}; };
// Question: should bhalf_t be supported ?
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
...@@ -67,7 +72,6 @@ struct Subtract ...@@ -67,7 +72,6 @@ struct Subtract
y = x0 - x1; y = x0 - x1;
}; };
// Question: should half_t be supported ?
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
...@@ -75,7 +79,6 @@ struct Subtract ...@@ -75,7 +79,6 @@ struct Subtract
y = x0 - x1; y = x0 - x1;
}; };
// Question: should bhalf_t be supported ?
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
...@@ -87,33 +90,25 @@ struct Subtract ...@@ -87,33 +90,25 @@ struct Subtract
} }
}; };
struct AlphaBetaAdd struct Bilinear
{ {
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename T> template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const float& x1) const operator()<float, float, float>(float& y, const float& x0, const float& x1) const
{ {
y = alpha_ * x0 + beta_ * x1; y = alpha_ * x0 + beta_ * x1;
}; };
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = static_cast<double>(alpha_) * x0 + static_cast<double>(beta_) * x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{ {
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1)); y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
}; };
float alpha_; float alpha_;
...@@ -141,13 +136,12 @@ struct AddRelu ...@@ -141,13 +136,12 @@ struct AddRelu
y = a > 0.0 ? a : 0.0; y = a > 0.0 ? a : 0.0;
}; };
// Question: should half_t be supported ?
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{ {
const half_t a = x0 + x1; const half_t a = x0 + x1;
y = a > static_cast<half_t>(0.0f) ? a : static_cast<half_t>(0.0f); y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
}; };
}; };
...@@ -176,7 +170,6 @@ struct AddHardswish ...@@ -176,7 +170,6 @@ struct AddHardswish
y = c; y = c;
}; };
// Question: should half_t be supported ?
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
......
...@@ -159,7 +159,7 @@ struct Normalize ...@@ -159,7 +159,7 @@ struct Normalize
using ck::math::sqrt; using ck::math::sqrt;
float variance = mean_square - (mean * mean); float variance = mean_square - (mean * mean);
y = ((x - mean) / sqrt(variance + static_cast<float>(epsilon_))) * gamma + beta; y = ((x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * gamma + beta;
}; };
template <> template <>
......
...@@ -23,19 +23,19 @@ template <typename GridwiseGemm, ...@@ -23,19 +23,19 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename FloatC0, typename FloatC0,
typename FloatC1, typename FloatC1,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation, typename C1ElementwiseOperation,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsReduceAccElementwiseOperation, typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock, typename ReduceGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
...@@ -46,15 +46,15 @@ __global__ void ...@@ -46,15 +46,15 @@ __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid, const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_c1_grid, const FloatC1* __restrict__ p_d0_grid,
DPtrsGlobal p_ds_grid, ReducePtrsGlobal p_reduces_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const C1ElementwiseOperation c1_element_op, const C1ElementwiseOperation c1_element_op,
const DxsInElementwiseOperation dxs_in_element_op, const ReduceInElementwiseOperations reduce_in_element_ops,
const DxsReduceAccElementwiseOperation dxs_out_element_op, const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -63,7 +63,7 @@ __global__ void ...@@ -63,7 +63,7 @@ __global__ void
c0_grid_desc_mblock_mperblock_nblock_nperblock, c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock, c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...@@ -72,42 +72,42 @@ __global__ void ...@@ -72,42 +72,42 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_c0_grid, p_bias_grid,
p_c1_grid, p_d0_grid,
p_ds_grid, p_reduces_grid,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op, c1_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op, reduce_out_element_ops,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock, c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock, c1_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock, reduce_grid_desc_mblock_mperblock,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = p_c0_grid; ignore = p_bias_grid;
ignore = p_c1_grid; ignore = p_d0_grid;
ignore = p_ds_grid; ignore = p_reduces_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = c1_element_op; ignore = c1_element_op;
ignore = dxs_in_element_op; ignore = reduce_in_element_ops;
ignore = dxs_out_element_op; ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock; ignore = reduce_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -119,22 +119,22 @@ template <typename FloatAB, ...@@ -119,22 +119,22 @@ template <typename FloatAB,
typename FloatC0, typename FloatC0,
typename FloatC1, typename FloatC1,
typename FloatReduceAcc, typename FloatReduceAcc,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation, typename C1ElementwiseOperation,
typename DxsReduceOperation, typename ReduceOperations,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsReduceAccElementwiseOperation, typename ReduceAccElementwiseOperations,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation, typename ReduceGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename C0GridDesc_M_N, typename C0GridDesc_M_N,
typename C1GridDesc_M_N, typename C1GridDesc_M_N,
typename DGridDesc_M, typename ReduceGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m) MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
{ {
const auto M = d_grid_desc_m.GetLength(I0); const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock; const auto MBlock = M / MPerBlock;
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor( const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m, d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))), make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
return d_grid_desc_mblock_mperblock; return reduce_grid_desc_mblock_mperblock;
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
...@@ -352,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -352,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock = using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>; remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap> template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void
const FloatAB* __restrict__ p_b_grid, Run(const FloatAB* __restrict__ p_a_grid,
FloatC* __restrict__ p_c_grid, const FloatAB* __restrict__ p_b_grid,
const FloatC0* __restrict__ p_c0_grid, FloatC* __restrict__ p_c_grid,
const FloatC1* __restrict__ p_c1_grid, const FloatC0* __restrict__ p_bias_grid,
DPtrsGlobal p_ds_grid, const FloatC1* __restrict__ p_d0_grid,
void* __restrict__ p_shared, ReducePtrsGlobal p_reduces_grid,
const AElementwiseOperation& a_element_op, void* __restrict__ p_shared,
const BElementwiseOperation& b_element_op, const AElementwiseOperation& a_element_op,
const CElementwiseOperation& c_element_op, const BElementwiseOperation& b_element_op,
const C1ElementwiseOperation& c1_element_op, const CElementwiseOperation& c_element_op,
const DxsInElementwiseOperation& dxs_in_element_op, const C1ElementwiseOperation& c1_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op, const ReduceInElementwiseOperations& reduce_in_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const ReduceAccElementwiseOperations& reduce_out_element_ops,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, c1_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map) const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -390,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -390,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -725,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -725,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed( make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{})); make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock // VGPR reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock = constexpr auto reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mblock_mperblock // VGPR reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock = constexpr auto reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>( auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
...@@ -759,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -759,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1, 1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple( auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) { [&](auto I) {
auto p_d_grid = p_ds_grid[I]; auto p_reduce_grid = p_reduces_grid[I];
auto d_out_element_op = dxs_out_element_op[I]; auto reduce_acc_element_op = reduce_out_element_ops[I];
return ThreadwiseTensorSliceTransfer_v1r3< return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc, FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>, remove_pointer_t<decltype(p_reduce_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock), decltype(reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock), decltype(reduce_grid_desc_mblock_mperblock),
decltype(d_out_element_op), decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>, Sequence<1, mreduce_per_thread>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I), ReduceGlobalMemoryDataOperation::At(I),
1, 1,
false>{d_grid_desc_mblock_mperblock, false>{reduce_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op}; reduce_acc_element_op};
}, },
Number<p_ds_grid.Size()>{}); Number<p_reduces_grid.Size()>{});
// c0 and c1 // c0 and c1
constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
...@@ -909,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -909,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf); c_grid_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In]; auto& p_reduce_grid = p_reduces_grid[In];
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d_thread_buf = auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>( make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize()); reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& d_in_element_op = dxs_in_element_op[In]; auto& reduce_in_element_op = reduce_in_element_ops[In];
auto& d_reduce_thread_copy_vgpr_to_global = auto& reduce_thread_copy_vgpr_to_global =
dxs_reduce_thread_copy_vgpr_to_global(In); reduce_tuple_thread_copy_vgpr_to_global(In);
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>; using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
using ThreadwiseReduce = using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc, ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock), decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock), decltype(reduce_thread_desc_mperblock),
DReduceOperation, ReduceOperation,
false>; false>;
// Global write Gemm shuffle + reduction // Global write Gemm shuffle + reduction
const auto d_zeroVal = const auto reduce_identityVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>(); ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; }); [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
// reduce in VGPR // reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) { static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
...@@ -946,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -946,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset( Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{}; make_tuple(im, in))>{};
d_in_element_op(c_reduce_thread_buf(offset), reduce_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset)); c_reduce_thread_buf(offset));
}); });
}); });
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
// copy from VGPR to Global // copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global.Run( reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
d_reduce_thread_desc_mblock_mperblock, make_tuple(I0, I0),
make_tuple(I0, I0), reduce_thread_buf,
d_thread_buf, reduce_grid_desc_mblock_mperblock,
d_grid_desc_mblock_mperblock, reduce_grid_buf);
d_grid_buf);
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock, reduce_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1])); make_tuple(c_global_step[I0], c_global_step[I1]));
} }
}); });
......
...@@ -21,16 +21,16 @@ namespace ck { ...@@ -21,16 +21,16 @@ namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsReduceAccElementwiseOperation, typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock, typename ReduceGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
...@@ -41,17 +41,17 @@ __global__ void ...@@ -41,17 +41,17 @@ __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
DPtrsGlobal p_ds_grid, ReducePtrsGlobal p_reduces_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op, const ReduceInElementwiseOperations reduce_in_element_ops,
const DxsReduceAccElementwiseOperation dxs_out_element_op, const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...@@ -60,32 +60,32 @@ __global__ void ...@@ -60,32 +60,32 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_ds_grid, p_reduces_grid,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op, reduce_out_element_ops,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock, reduce_grid_desc_mblock_mperblock,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = p_ds_grid; ignore = p_reduces_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = dxs_in_element_op; ignore = reduce_in_element_ops;
ignore = dxs_out_element_op; ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock; ignore = reduce_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -95,19 +95,19 @@ template <typename FloatAB, ...@@ -95,19 +95,19 @@ template <typename FloatAB,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
typename FloatReduceAcc, typename FloatReduceAcc,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsReduceOperation, typename ReduceOperations,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsReduceAccElementwiseOperation, typename ReduceAccElementwiseOperations,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation, typename ReduceGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename DGridDesc_M, typename ReduceGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -293,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -293,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m) MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
{ {
const auto M = d_grid_desc_m.GetLength(I0); const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock; const auto MBlock = M / MPerBlock;
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor( const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m, d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))), make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
return d_grid_desc_mblock_mperblock; return reduce_grid_desc_mblock_mperblock;
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
...@@ -318,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -318,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock = using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>; remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap> template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void
const FloatAB* __restrict__ p_b_grid, Run(const FloatAB* __restrict__ p_a_grid,
FloatC* __restrict__ p_c_grid, const FloatAB* __restrict__ p_b_grid,
DPtrsGlobal p_ds_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, ReducePtrsGlobal p_reduces_grid,
const AElementwiseOperation& a_element_op, void* __restrict__ p_shared,
const BElementwiseOperation& b_element_op, const AElementwiseOperation& a_element_op,
const CElementwiseOperation& c_element_op, const BElementwiseOperation& b_element_op,
const DxsInElementwiseOperation& dxs_in_element_op, const CElementwiseOperation& c_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op, const ReduceInElementwiseOperations& reduce_in_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const ReduceAccElementwiseOperations& reduce_out_element_ops,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map) const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -706,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -706,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed( make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{})); make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock // VGPR reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock = constexpr auto reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mblock_mperblock // VGPR reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock = constexpr auto reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>( auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
...@@ -740,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -740,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1, 1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple( auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) { [&](auto I) {
auto p_d_grid = p_ds_grid[I]; auto p_reduce_grid = p_reduces_grid[I];
auto d_out_element_op = dxs_out_element_op[I]; auto reduce_acc_element_op = reduce_out_element_ops[I];
return ThreadwiseTensorSliceTransfer_v1r3< return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc, FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>, remove_pointer_t<decltype(p_reduce_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock), decltype(reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock), decltype(reduce_grid_desc_mblock_mperblock),
decltype(d_out_element_op), decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>, Sequence<1, mreduce_per_thread>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I), ReduceGlobalMemoryDataOperation::At(I),
1, 1,
false>{d_grid_desc_mblock_mperblock, false>{reduce_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op}; reduce_acc_element_op};
}, },
Number<p_ds_grid.Size()>{}); Number<p_reduces_grid.Size()>{});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
...@@ -797,35 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -797,35 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(I0, I0), make_tuple(I0, I0),
c_reduce_thread_buf); c_reduce_thread_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In]; auto& p_reduce_grid = p_reduces_grid[In];
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d_thread_buf = auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>( make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize()); reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& d_in_element_op = dxs_in_element_op[In]; auto& reduce_in_element_op = reduce_in_element_ops[In];
auto& d_reduce_thread_copy_vgpr_to_global = auto& reduce_thread_copy_vgpr_to_global =
dxs_reduce_thread_copy_vgpr_to_global(In); reduce_tuple_thread_copy_vgpr_to_global(In);
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>; using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
using ThreadwiseReduce = using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc, ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock), decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock), decltype(reduce_thread_desc_mperblock),
DReduceOperation, ReduceOperation,
false>; false>;
// Global write Gemm shuffle + reduction // Global write Gemm shuffle + reduction
const auto d_identityVal = const auto reduce_identityVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>(); ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_identityVal; }); [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
// reduce in VGPR // reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) { static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
...@@ -834,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -834,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset( Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{}; make_tuple(im, in))>{};
d_in_element_op(c_reduce_thread_buf(offset), reduce_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset)); c_reduce_thread_buf(offset));
}); });
}); });
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
// copy from VGPR to Global // copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global.Run( reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
d_reduce_thread_desc_mblock_mperblock, make_tuple(I0, I0),
make_tuple(I0, I0), reduce_thread_buf,
d_thread_buf, reduce_grid_desc_mblock_mperblock,
d_grid_desc_mblock_mperblock, reduce_grid_buf);
d_grid_buf);
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock, reduce_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1])); make_tuple(c_global_step[I0], c_global_step[I1]));
} }
}); });
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace ck {
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatC0,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_NBlock_NPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_layernorm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, // MxN
const FloatC0* __restrict__ p_c0_bias_grid, // 1xN
const FloatC0* __restrict__ p_c0_add_grid, // MxN
const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// TODO ANT: separate into MMA + Epilogue
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_bias_grid,
p_c0_add_grid,
p_c0_gamma_grid,
p_c0_beta_grid,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_nblock_nperblock,
block_2_ctile_map);
// TODO ANT: Run layernorm epilogue here
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_c0_bias_grid;
ignore = p_c0_add_grid;
ignore = p_c0_gamma_grid;
ignore = p_c0_beta_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_nblock_nperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers
// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example,
// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename FloatC0,
typename FloatReduceAcc, // Data type after shuffle
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename C0GridDesc_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
// Align 16 bytes (maximum LDS read/write width)
constexpr auto c_block_size_aligned =
math::integer_least_multiple(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
sizeof(FloatCShuffle),
16) /
sizeof(FloatCShuffle);
// LDS allocation for reduction workspace
constexpr index_t c_lds_workspace_size = BlockSize;
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size_aligned * sizeof(FloatCShuffle) +
c_lds_workspace_size * sizeof(FloatReduceAcc));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
return false;
// in order to reduce N dim without elaborate sync across CUs in single kernel, one
// workgroup must span the entire N extent
if(math::integer_divide_ceil(N, NPerBlock) > 1)
{
return false;
}
// static check: all waves in the workgroups combined must cover whole N extent in order
// to have efficient N-dim reduction
static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave,
"condition not met for efficient layernorm");
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// for bias, beta, gamma
__host__ __device__ static constexpr auto
MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n)
{
const auto N = c0_grid_desc_n.GetLength(I0);
const auto NBlock = N / NPerBlock;
const auto c0_grid_desc_nblock_nperblock = transform_tensor_descriptor(
c0_grid_desc_n,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return c0_grid_desc_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using C0GridDescriptor_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_bias_grid, // 1xN
const FloatC0* __restrict__ p_c0_add_grid, // MxN
const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const AccElementwiseOperation& acc_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_bias_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
// Note: c0_add is of same layout as c so we don't declare new c0_add_desc here
auto c0_add_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_add_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_gamma_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
auto c0_beta_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0);
// for broadcasting bias, beta, gamma
const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c0_grid_desc_nblock_nperblock,
make_tuple(make_insert_transform(I1),
make_insert_transform(I1),
make_pass_through_transform(NBlock),
make_pass_through_transform(NPerBlock)),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
BlockSize,
"wrong!");
static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
0 &&
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
0,
"wrong!");
constexpr index_t mreduce_per_thread =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);
constexpr index_t nreduce_per_thread =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);
constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
Sequence<mreduce_per_thread, nreduce_per_thread>{};
// pytorch default
// https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
static constexpr FloatReduceAcc epsilon = 1e-5;
// VGPR c_reduce_thread_desc_mperblock_nperblock
constexpr auto c_reduce_thread_desc_mperblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC0>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// Align 16 bytes (maximum LDS read/write width)
constexpr auto c_block_size_aligned =
math::integer_least_multiple(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
sizeof(FloatCShuffle),
16) /
sizeof(FloatCShuffle);
auto d_reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<FloatReduceAcc*>(static_cast<FloatCShuffle*>(p_shared) +
c_block_size_aligned),
BlockSize);
// Sum thread workspace
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// Squared sum thread workspace
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
const auto c_reduce_thread_cluster_idx =
c_reduce_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto c_reduce_thread_data_idx_begin =
c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatReduceAcc,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
auto c_reduce_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatCShuffle,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_block_desc_mperblock_nperblock),
tensor_operation::element_wise::PassThrough,
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{c_reduce_block_desc_mperblock_nperblock,
c_reduce_thread_data_idx_begin,
tensor_operation::element_wise::PassThrough{}};
auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC0,
FloatC0,
decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
1,
true>(c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0],
c_reduce_thread_data_idx_begin[I0],
block_work_idx[I1],
c_reduce_thread_data_idx_begin[I1]));
// Note: c0_add is of same layout as c so we don't declare new c0_add_desc here
auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC0,
FloatC0,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
1,
true>(c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0],
c_reduce_thread_data_idx_begin[I0],
block_work_idx[I1],
c_reduce_thread_data_idx_begin[I1]));
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
block_sync_lds();
// load from LDS and global, add bias
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_bias_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
FloatReduceAcc out;
acc_element_op(out,
c_reduce_thread_buf(i) +
static_cast<FloatReduceAcc>(c0_thread_buf(i)));
c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias)
});
c0_add_thread_copy_global_to_vgpr.Run(
c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_add_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
c_reduce_thread_buf(i) +=
static_cast<FloatReduceAcc>(c0_thread_buf(i)); // add
});
// layernorm
{
using ThreadwiseReduceD0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::Add,
false>;
using ThreadwiseReduceD1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::SquaredAdd,
false>;
const auto d0_zeroVal =
ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>();
const auto d1_zeroVal =
ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d1_thread_buf(i) = d1_zeroVal; });
// reduce sum in VGPR
ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf);
// reduce squared sum in VGPR
ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// reduce within workgroup
using BlockwiseReduce = PartitionedBlockwiseReduction<
FloatReduceAcc,
BlockSize,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
Sequence<1, 0>, // ThreadClusterArrangeOrder
reduce::Add,
false>;
static_for<0, mreduce_per_thread, 1>{}([&](auto i) {
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf,
d0_thread_buf(i)); // blockwise reduced sum
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf,
d1_thread_buf(i)); // blockwise reduced squared sum
});
// normalize
const index_t NRaw =
c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0]
.GetUpperLengths()[I1]; // TODO: proper handle
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto dst_offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
constexpr auto src_offset =
Number<d_reduce_thread_desc_mperblock.CalculateOffset(
make_tuple(im))>{};
FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
});
});
// scaling
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_gamma_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
c_reduce_thread_buf(i) *=
static_cast<FloatReduceAcc>(c0_thread_buf(i)); // * gamma
});
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_beta_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
c_reduce_thread_buf(i) +=
static_cast<FloatReduceAcc>(c0_thread_buf(i)); // + beta
});
block_sync_lds();
c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf,
c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf);
} // end layernorm
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on C0
c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on C0_add
c0_add_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
}
};
} // namespace ck
...@@ -49,7 +49,8 @@ template <typename InDataType, ...@@ -49,7 +49,8 @@ template <typename InDataType,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize,
bool SweepOnce>
struct GridwiseSoftmax_mk_to_mk struct GridwiseSoftmax_mk_to_mk
{ {
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
...@@ -75,19 +76,6 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -75,19 +76,6 @@ struct GridwiseSoftmax_mk_to_mk
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Max,
false>; // PropagateNan
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Max,
false>; // PropagateNan
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -105,6 +93,11 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -105,6 +93,11 @@ struct GridwiseSoftmax_mk_to_mk
AccDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_value_global) OutDataType* const __restrict__ p_out_value_global)
{ {
if constexpr(SweepOnce)
{
num_k_block_tile_iteration = 1;
}
// LDS // LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...@@ -149,6 +142,20 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -149,6 +142,20 @@ struct GridwiseSoftmax_mk_to_mk
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
// another value_max. As numbers become non-zero, effectively it allows invalid values to
// slip through and contribute to the accumulated result.
//
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
// be identified as an invalid value. We can then discard the invalid values which
// originally failed the bound check during accumulation. This allows to ignore values that
// failed bound check even after multiple math manipulations.
//
// NOTE: reset coordinate after every step because the same threadwise copy will sweep
// through global memory 3 times back and forth
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType, auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType, AccDataType,
GridDesc_M_K, GridDesc_M_K,
...@@ -158,7 +165,8 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -158,7 +165,8 @@ struct GridwiseSoftmax_mk_to_mk
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
false>( true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(
in_grid_desc_m_k, in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock + block_local_id * reduceSizePerBlock +
...@@ -198,21 +206,39 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -198,21 +206,39 @@ struct GridwiseSoftmax_mk_to_mk
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{}); PassThroughOp{});
constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize); constexpr auto in_thread_copy_fwd_step =
constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto in_thread_copy_bwd_step =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
/// ///
/// max(x) /// max(x)
/// ///
const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>( using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
p_in_value_global, AccDataType,
in_grid_desc_m_k.GetElementSpaceSize(), BlockSize,
reduce::Max::template GetIdentityValue<InDataType>()); ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using ThreadwiseMaxReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());
index_t reducedTiles = 0; index_t reducedTiles = 0;
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_non_zero, in_global_val_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
...@@ -232,26 +258,6 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -232,26 +258,6 @@ struct GridwiseSoftmax_mk_to_mk
/// ///
/// sum(exp(x - max(x))) /// sum(exp(x - max(x)))
/// ///
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
// another value_max. As numbers become non-zero, effectively it allows invalid values to
// slip through and contribute to the accumulated result.
//
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
// be identified as an invalid value. We can then discard the invalid values which
// originally failed the bound check during accumulation. This allows to ignore values that
// failed bound check even after multiple math manipulations.
const auto in_global_val_buf_oob_nan =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
NumericLimits<InDataType>::QuietNaN());
using BlockwiseSumReduce = PartitionedBlockwiseReduction< using BlockwiseSumReduce = PartitionedBlockwiseReduction<
AccDataType, AccDataType,
BlockSize, BlockSize,
...@@ -272,22 +278,25 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -272,22 +278,25 @@ struct GridwiseSoftmax_mk_to_mk
reducedTiles = 0; reducedTiles = 0;
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, if constexpr(!SweepOnce)
in_global_val_buf_oob_nan, {
thread_buffer_desc, threadwise_src_load.Run(in_grid_desc_m_k,
make_tuple(I0, I0), in_global_val_buf,
in_thread_buf); thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
}
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_buf(Number<offset>{}) = out_thread_buf(Number<offset>{}) =
math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)); math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
}); });
}); });
ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf); ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
...@@ -309,11 +318,14 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -309,11 +318,14 @@ struct GridwiseSoftmax_mk_to_mk
{ {
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, if constexpr(!SweepOnce)
in_global_val_buf_oob_nan, {
thread_buffer_desc, threadwise_src_load.Run(in_grid_desc_m_k,
make_tuple(I0, I0), in_global_val_buf,
in_thread_buf); thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
...@@ -340,18 +352,27 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -340,18 +352,27 @@ struct GridwiseSoftmax_mk_to_mk
} }
else else
{ {
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_prior_dst_buf;
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, if constexpr(!SweepOnce)
in_global_val_buf_oob_nan, {
thread_buffer_desc, threadwise_src_load.Run(in_grid_desc_m_k,
make_tuple(I0, I0), in_global_val_buf,
in_thread_buf); thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
}
threadwise_dst_load.Run(out_grid_desc_m_k, threadwise_dst_load.Run(out_grid_desc_m_k,
out_global_val_buf, out_global_val_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
out_thread_buf); in_prior_dst_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
...@@ -360,7 +381,7 @@ struct GridwiseSoftmax_mk_to_mk ...@@ -360,7 +381,7 @@ struct GridwiseSoftmax_mk_to_mk
out_thread_buf(Number<offset>{}) = out_thread_buf(Number<offset>{}) =
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) / alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
accu_value_buf(iM) + accu_value_buf(iM) +
beta * out_thread_buf(Number<offset>{}); beta * in_prior_dst_buf(Number<offset>{});
}); });
}); });
......
...@@ -30,6 +30,8 @@ struct ThreadwiseReduction ...@@ -30,6 +30,8 @@ struct ThreadwiseReduction
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Op = OpReduce;
template <typename SrcBufferType, typename DstBufferType> template <typename SrcBufferType, typename DstBufferType>
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
{ {
......
...@@ -236,9 +236,14 @@ template <typename SrcData, ...@@ -236,9 +236,14 @@ template <typename SrcData,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
bool InvalidElementAsNaN = false,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2 struct ThreadwiseTensorSliceTransfer_v2
{ {
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
(!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types");
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
...@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector); i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = if constexpr(InvalidElementAsNaN)
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]); {
dst_buf(Number<dst_offset>{}) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
}); });
if constexpr(idx_1d.value != num_access - 1) if constexpr(idx_1d.value != num_access - 1)
......
...@@ -932,14 +932,14 @@ using int8x64_t = typename vector_type<int8_t, 64>::type; ...@@ -932,14 +932,14 @@ using int8x64_t = typename vector_type<int8_t, 64>::type;
// Convert X to Y // Convert X to Y
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{ {
union union
{ {
...@@ -952,7 +952,7 @@ inline __host__ __device__ float type_convert<float, bhalf_t>(bhalf_t x) ...@@ -952,7 +952,7 @@ inline __host__ __device__ float type_convert<float, bhalf_t>(bhalf_t x)
// convert fp32 to bfp16 // convert fp32 to bfp16
template <> template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{ {
union union
{ {
......
...@@ -12,21 +12,27 @@ template <typename T, typename Enable = void> ...@@ -12,21 +12,27 @@ template <typename T, typename Enable = void>
struct PrintAsType; struct PrintAsType;
template <typename T> template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::value> struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{ {
using type = float; using type = float;
__host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
}; };
template <> template <>
struct PrintAsType<ck::half_t, void> struct PrintAsType<ck::half_t, void>
{ {
using type = float; using type = float;
__host__ __device__ static void Print(const ck::half_t& p)
{
printf("%.3f ", static_cast<type>(p));
}
}; };
template <typename T> template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value> struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
{ {
using type = int; using type = int;
__host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
}; };
} // namespace detail } // namespace detail
...@@ -41,7 +47,6 @@ struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value ...@@ -41,7 +47,6 @@ struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value
template <typename T, index_t element_stride = 1, index_t row_bytes = 128> template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
__device__ void print_shared(T const* p_shared, index_t num_elements) __device__ void print_shared(T const* p_shared, index_t num_elements)
{ {
using PrintType = typename detail::PrintAsType<T>::type;
constexpr index_t row_elements = row_bytes / sizeof(T); constexpr index_t row_elements = row_bytes / sizeof(T);
static_assert((element_stride >= 1 && element_stride <= row_elements), static_assert((element_stride >= 1 && element_stride <= row_elements),
"element_stride should between [1, row_elements]"); "element_stride should between [1, row_elements]");
...@@ -63,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements) ...@@ -63,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
printf("elem %5d: ", i); printf("elem %5d: ", i);
for(index_t j = 0; j < row_elements; j += element_stride) for(index_t j = 0; j < row_elements; j += element_stride)
{ {
printf("%.0f ", static_cast<PrintType>(p_shared[i + j])); detail::PrintAsType<T>::Print(p_shared[i + j]);
} }
printf("\n"); printf("\n");
......
...@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) ...@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
template <typename T> template <typename T>
__device__ T exp(T x); __device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <> template <>
__device__ float exp<float>(float x) __device__ float exp<float>(float x)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment