Commit b89a88b5 authored by Adam Osewski

Merge branch 'develop' into wavelet_model

parents 41d5fca7 43c898f6
......@@ -9,6 +9,7 @@ namespace device {
enum struct GemmSpecialization
{
// Gemm
Default,
MPadding,
NPadding,
......@@ -17,6 +18,15 @@ enum struct GemmSpecialization
MKPadding,
NKPadding,
MNKPadding,
// Gemm + Gemm
OPadding,
MOPadding,
NOPadding,
KOPadding,
MNOPadding,
MKOPadding,
NKOPadding,
MNKOPadding,
};
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
......@@ -31,6 +41,14 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
case GemmSpecialization::MKPadding: return "MKPadding";
case GemmSpecialization::NKPadding: return "NKPadding";
case GemmSpecialization::MNKPadding: return "MNKPadding";
case GemmSpecialization::OPadding: return "OPadding";
case GemmSpecialization::MOPadding: return "MOPadding";
case GemmSpecialization::NOPadding: return "NOPadding";
case GemmSpecialization::KOPadding: return "KOPadding";
case GemmSpecialization::MNOPadding: return "MNOPadding";
case GemmSpecialization::MKOPadding: return "MKOPadding";
case GemmSpecialization::NKOPadding: return "NKOPadding";
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
default: return "Unrecognized specialization!";
}
}
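// A minimal host-side usage sketch (illustrative only, not part of any file in this change):
// the new Gemm + Gemm enumerators work with the existing string helper, e.g.
//
//   constexpr auto spec = GemmSpecialization::MNKOPadding;
//   std::cout << getGemmSpecializationString(spec) << std::endl; // prints "MNKOPadding"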
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
typename InElementwiseOp,
typename AccElementwiseOp,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
InElementwiseOp,
AccElementwiseOp,
Rank>
{
static constexpr index_t kRank = Rank;
static constexpr index_t kNumReduceDim = NumReduceDim;
virtual index_t GetRank() const override { return kRank; }
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
// Used to borrow some handy functions from DeviceReduceMultiBlock
using Reduction = DeviceReduceMultiBlock<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
reduce::Add,
InElementwiseOp,
AccElementwiseOp,
InMemoryDataOperationEnum::Set,
false, // PropagateNan
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
1>; // OutDstVectorSize
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
false>;
using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
true>;
struct Argument : public Reduction::Argument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims,
AccDataType alpha,
AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
: Reduction::Argument(inLengths,
inStrides,
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
in_elementwise_op,
acc_elementwise_op),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
}
AccDataType alpha_;
AccDataType beta_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
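// Kernel selection note: when the reduced (K) length fits within one block tile
// (KThreadClusterSize * KThreadSliceSize), the SweepOnce variant is dispatched, which
// presumably lets the kernel process each row in a single pass over global memory;
// otherwise the generic multi-pass variant is used.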
bool sweep_once =
in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>
: kernel_softmax<GridwiseSoftmaxGeneric,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m_k,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_))
{
return false;
}
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false;
}
return true;
};
//
// @brief Makes a unique pointer to the Argument class.
//
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) along which the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor
// @param[out] out_dev Typeless pointer in device memory storing the output tensor
// @param[in] in_elementwise_op The input elementwise operation.
// @param[in] acc_elementwise_op The accumulation elementwise operation.
//
// @return Unique pointer to the Argument class.
//
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
*static_cast<const AccDataType*>(alpha),
*static_cast<const AccDataType*>(beta),
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
in_elementwise_op,
acc_elementwise_op);
};
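// A host-side usage sketch (illustrative only; `softmax` stands for some concrete
// DeviceSoftmaxImpl instance, and the shapes/operators below are assumptions):
//
//   AccDataType alpha = 1, beta = 0;
//   auto arg_ptr     = softmax.MakeArgumentPointer({M, K}, {K, 1}, {1}, &alpha, &beta,
//                                                  p_in, p_out, InElementwiseOp{},
//                                                  AccElementwiseOp{});
//   auto invoker_ptr = softmax.MakeInvokerPointer();
//   if(softmax.IsSupportedArgument(arg_ptr.get()))
//       invoker_ptr->Run(arg_ptr.get());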
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -12,12 +12,220 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename TensorDesc,
typename TileLengths, // Tuple<...>
typename DoPads> // Sequence<bool, bool, ...>
__host__ __device__ constexpr auto
PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads)
{
constexpr index_t num_dim = DoPads::Size();
static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(),
"wrong! inconsistent # of dimensions");
// transforms
const auto transforms = generate_tuple(
[&](auto idim) {
const auto MRaw = desc.GetLength(idim);
const auto MPerTile = tile_lengths[idim];
const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile;
const auto MPad = M - MRaw;
const bool DoPadM = DoPads::At(idim);
const auto MTransform = conditional_expr<DoPadM>(make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(MRaw));
return MTransform;
},
Number<num_dim>{});
// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
// upper dimension Id
const auto upper_dimss = lower_dimss;
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
}
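// Worked example of the per-dimension arithmetic above (illustrative numbers only):
//   MRaw = 100, MPerTile = 32  ->  M = integer_divide_ceil(100, 32) * 32 = 128, MPad = 28,
// so the dimension is right-padded from 100 to 128. A dimension whose DoPads entry is
// false keeps its raw length via the pass-through transform.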
// M/N/K/OPerTileType could be index_t or Number<>
template <GemmSpecialization GemmSpec,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType,
typename OPerTileType>
struct GemmGemmPadder
{
// TODO: hard to scale; use mask instead
static constexpr bool PadM =
GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::MOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
GemmSpec == GemmSpecialization::MKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
static constexpr bool PadN =
GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
static constexpr bool PadK =
GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
static constexpr bool PadO =
GemmSpec == GemmSpecialization::OPadding || GemmSpec == GemmSpecialization::MOPadding ||
GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::KOPadding ||
GemmSpec == GemmSpecialization::MNOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
// A[M, K]
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
return PadTensorDescriptor(
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
}
// B[K, N]
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
return PadTensorDescriptor(
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
// B1[Gemm1N, Gemm1K] = B1[O, N]
template <typename B1Desc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
{
return PadTensorDescriptor(
b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence<PadO, PadN>{});
}
// C[M, Gemm1N] = C[M, O]
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
OPerTileType OPerTile_;
};
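// A usage sketch for the Gemm + Gemm case (hypothetical tile sizes and descriptor names,
// shown only to illustrate which dimensions each Pad*Descriptor call touches):
//
//   const GemmGemmPadder<GemmSpecialization::MNKOPadding, index_t, index_t, index_t, index_t>
//       padder{128 /*MPerTile*/, 128 /*NPerTile*/, 32 /*KPerTile*/, 128 /*OPerTile*/};
//   const auto a_padded  = padder.PadADescriptor_M_K(a_desc_mraw_kraw);   // pads M, K
//   const auto b0_padded = padder.PadBDescriptor_N_K(b0_desc_nraw_kraw);  // pads N, K
//   const auto b1_padded = padder.PadB1Descriptor_N_K(b1_desc_oraw_nraw); // pads O, N
//   const auto c_padded  = padder.PadCDescriptor_M_N(c_desc_mraw_oraw);   // pads M, O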
// M/N/KPerTileType could be index_t or Number<>
template <GemmSpecialization GemmSpec,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder
struct GemmPadder
{
static constexpr bool PadM =
(GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
static constexpr bool PadN =
(GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
static constexpr bool PadK =
(GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
return PadTensorDescriptor(
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
return PadTensorDescriptor(
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
// Alias of GemmPadder; to deprecate
template <GemmSpecialization GemmSpec,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
{
};
// M/N/KPerTileType could be index_t or Number<>
template <bool PadM,
bool PadN,
bool PadK,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct GemmPadder_v2
{
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
return PadTensorDescriptor(
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
return PadTensorDescriptor(
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
// M/N/KPerTileType could be index_t or Number<>
template <bool PadM,
bool PadN,
bool PadK,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -37,8 +245,7 @@ struct MatrixPadder
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
if constexpr(PadM && PadK)
{
// pad both M and K
return transform_tensor_descriptor(a_desc_mraw_kraw,
......@@ -47,8 +254,7 @@ struct MatrixPadder
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
else if constexpr(PadM && (!PadK))
{
// pad M, but not K
return transform_tensor_descriptor(
......@@ -57,8 +263,7 @@ struct MatrixPadder
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
else if constexpr((!PadM) && PadK)
{
// pad K, but not M
return transform_tensor_descriptor(
......@@ -87,8 +292,7 @@ struct MatrixPadder
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
if constexpr(PadN && PadK)
{
// pad both N and K
return transform_tensor_descriptor(b_desc_nraw_kraw,
......@@ -97,8 +301,7 @@ struct MatrixPadder
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
else if constexpr(PadN && (!PadK))
{
// pad N, but not K
return transform_tensor_descriptor(
......@@ -107,8 +310,7 @@ struct MatrixPadder
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
else if constexpr((!PadN) && PadK)
{
// pad K, but not N
return transform_tensor_descriptor(
......@@ -137,8 +339,7 @@ struct MatrixPadder
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
if constexpr(PadM && PadN)
{
// pad M and N
return transform_tensor_descriptor(c_desc_mraw_nraw,
......@@ -147,8 +348,7 @@ struct MatrixPadder
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
else if constexpr(PadM && (!PadN))
{
// pad M, but not N
return transform_tensor_descriptor(
......@@ -157,8 +357,7 @@ struct MatrixPadder
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
else if constexpr((!PadM) && PadN)
{
// pad N, but not M
return transform_tensor_descriptor(
......@@ -178,7 +377,6 @@ struct MatrixPadder
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -93,7 +93,7 @@ struct GNDHWC : public BaseTensorLayout
};
// input tensor
// packed GNWC/GNHWC/GNDHWC
// packed NWGC/NHWGC/NDHWGC
struct NWGC : public BaseTensorLayout
{
static constexpr const char* name = "NWGC";
......@@ -330,6 +330,54 @@ struct G_NDHW_K : public BaseTensorLayout
static constexpr const char* name = "G_NDHW_K";
};
// K-reduced output tensor (packed)
struct GNW : public BaseTensorLayout
{
static constexpr const char* name = "GNW";
};
struct GNHW : public BaseTensorLayout
{
static constexpr const char* name = "GNHW";
};
struct GNDHW : public BaseTensorLayout
{
static constexpr const char* name = "GNDHW";
};
// K-reduced output tensor (packed)
struct NWG : public BaseTensorLayout
{
static constexpr const char* name = "NWG";
};
struct NHWG : public BaseTensorLayout
{
static constexpr const char* name = "NHWG";
};
struct NDHWG : public BaseTensorLayout
{
static constexpr const char* name = "NDHWG";
};
// K-reduced output tensor (strided)
struct G_NW : public BaseTensorLayout
{
static constexpr const char* name = "G_NW";
};
struct G_NHW : public BaseTensorLayout
{
static constexpr const char* name = "G_NHW";
};
struct G_NDHW : public BaseTensorLayout
{
static constexpr const char* name = "G_NDHW";
};
} // namespace convolution
template <
......
......@@ -28,6 +28,13 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 + type_convert<float>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
......@@ -172,6 +179,14 @@ struct AddRelu
const float a = x0 + x1;
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
};
template <>
__host__ __device__ constexpr void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > 0.0f ? a : 0.0f;
};
};
struct AddHardswish
......@@ -210,6 +225,46 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
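// Why the exp-based form above matches the tanh formula in the comment: with
//   u = 2*x*(0.035677*x^2 + 0.797885) = 2*sqrt(2/pi)*(x + 0.044715*x^3)
// (0.797885 ~ sqrt(2/pi), 0.035677 ~ 0.044715*sqrt(2/pi)), the identity
//   tanh(u/2) = 2/(1 + exp(-u)) - 1
// gives 0.5*(1 + tanh(u/2)) = 0.5 + 0.5*(2/(1 + exp(-u)) - 1), which is exactly `cdf`.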
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
e = type_convert<E>(y);
}
template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
{
static_assert(is_valid_param_type_v<D>);
e = GetFastGeLU(c + type_convert<float>(d));
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -98,6 +98,18 @@ struct AddReluAdd
int32_t c = b + x2;
y = c;
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ constexpr void operator()<int4_t, int8_t, int4_t, int4_t>(
int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const
{
int32_t a = x0 + x1;
int32_t b = a > 0 ? a : 0;
int32_t c = b + x2;
y = c;
}
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
};
struct AddHardswishAdd
......@@ -177,7 +189,11 @@ struct AddAddFastGelu
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
......@@ -198,17 +214,44 @@ struct Normalize
// FIXME: is double absolutely necessary?
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T>
__host__ __device__ constexpr void operator()(
T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const;
template <typename T1, typename T2, typename T3>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& mean_square,
const T3& gamma,
const T3& beta) const;
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t>(half_t& y,
const half_t& x,
const float& mean,
const float& mean_square,
const half_t& gamma,
const half_t& beta) const
{
using ck::math::sqrt;
float variance = mean_square - (mean * mean);
float tmp_x = type_convert<float>(x);
float tmp_gamma = type_convert<float>(gamma);
float tmp_beta = type_convert<float>(beta);
float tmp_y =
((tmp_x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * tmp_gamma +
tmp_beta;
y = type_convert<half_t>(tmp_y);
};
template <>
__host__ __device__ constexpr void operator()<float>(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
__host__ __device__ constexpr void operator()<float, float, float>(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
{
using ck::math::sqrt;
......@@ -217,12 +260,12 @@ struct Normalize
};
template <>
__host__ __device__ constexpr void operator()<double>(double& y,
const double& x,
const double& mean,
const double& mean_square,
const double& gamma,
const double& beta) const
__host__ __device__ constexpr void operator()<double, double, double>(double& y,
const double& x,
const double& mean,
const double& mean_square,
const double& gamma,
const double& beta) const
{
using ck::math::sqrt;
......
......@@ -62,6 +62,14 @@ struct PassThrough
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
#endif
};
struct UnaryConvert
......@@ -89,6 +97,22 @@ struct Scale
float scale_;
};
struct ScaleAndResetNaNToMinusInfinity
{
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x;
};
float scale_;
};
struct UnaryDivide
{
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
......@@ -111,9 +135,13 @@ struct UnarySquare
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value,
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| is_same_v<T, int4_t>
#endif
,
"Data type is not supported by this operation!");
y = x * x;
};
};
......@@ -183,6 +211,27 @@ struct FastGelu
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
const ck::half_t& x) const
{
y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
}
};
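// Note: 0.70710678118f ~ 1/sqrt(2), so these overloads compute the "exact" GeLU
// y = 0.5*x*(1 + erf(x/sqrt(2))) from the comment above, in contrast to the tanh/exp
// approximation used by FastGelu.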
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultipleReduction,
index_t NumReduction,
typename InDataType,
typename OutDataTypePointerTuple,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M_Tuple,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple>
__global__ void
kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
const InElementwiseOperationTuple in_elementwise_op_tuple,
const AccElementwiseOperationTuple acc_elementwise_op_tuple,
index_t block_group_size,
index_t num_k_block_tile_iteration,
Array<AccDataType, NumReduction> alpha_values,
const InDataType* const __restrict__ p_in_value_global,
Array<AccDataType, NumReduction> beta_values,
OutDataTypePointerTuple p_out_value_global_tuple)
{
GridwiseMultipleReduction::Run(in_grid_desc_m_k,
out_grid_desc_m_tuple,
in_elementwise_op_tuple,
acc_elementwise_op_tuple,
block_group_size,
num_k_block_tile_iteration,
alpha_values,
p_in_value_global,
beta_values,
p_out_value_global_tuple);
};
template <index_t NumReduction,
typename InDataType,
typename OutDataTypePointerTuple,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M_Tuple,
typename ReduceOperation,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
typename OutDstVectorSizeSeq>
struct GridwiseMultipleReduction_mk_to_m_multiblock
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
NumReduction == OutGridDesc_M_Tuple::Size() &&
NumReduction == OutDstVectorSizeSeq::Size() &&
NumReduction == InElementwiseOperationTuple::Size() &&
NumReduction == AccElementwiseOperationTuple::Size(),
"All tuple should have the same size as the number of Reductions!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
const InElementwiseOperationTuple& in_elementwise_op_tuple,
const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
index_t block_group_size,
index_t num_k_block_tile_iteration,
Array<AccDataType, NumReduction> alpha_values,
const InDataType* const __restrict__ p_in_value_global,
Array<AccDataType, NumReduction> beta_values,
OutDataTypePointerTuple p_out_value_global_tuple)
{
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
// LDS, reused by all reductions
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto out_global_val_buf_tuple = generate_tuple(
[&](auto iR) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
},
Number<NumReduction>{});
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
auto in_thread_buf_tuple = generate_tuple(
[&](auto iR) {
(void)iR;
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
},
Number<NumReduction>{});
auto accu_value_buf_tuple = generate_tuple(
[&](auto iR) {
(void)iR;
return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>{};
},
Number<NumReduction>{});
static_for<0, NumReduction, 1>{}([&](auto iR) {
static_for<0, MThreadSliceSize, 1>{}(
[&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
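// Blocks are organized into groups of block_group_size: every block in a group works on
// the same M tile (selected by blkgroup_id), and block_local_id picks which contiguous
// chunk of the K dimension this block reduces.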
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, NumReduction, 1>{}([&](auto iR) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, NumReduction, 1>{}([&](auto iR) {
using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
using OutDataType = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf_tuple(iR)(I));
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
accu_value_buf_tuple(iR)(I));
accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
}
});
if(thread_k_cluster_id == 0)
{
if(block_group_size == 1 && !float_equal_zero{}(beta_values[iR]))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
decltype(out_grid_desc_m_tuple[iR]),
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSizeSeq::At(iR),
1,
false>(
out_grid_desc_m_tuple[iR],
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
out_global_val_buf_tuple(iR),
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf_tuple(iR)(I) +=
type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
});
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
decltype(out_grid_desc_m_tuple[iR]),
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSizeSeq::At(iR),
OutMemoryDataOperation,
1,
true>(
out_grid_desc_m_tuple[iR],
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf_tuple[iR],
out_grid_desc_m_tuple[iR],
out_global_val_buf_tuple(iR));
};
});
};
}; // struct GridwiseMultipleReduction_mk_to_m_multiblock
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultipleReduction,
index_t NumReduction,
typename InDataType,
typename OutDataTypePointerTuple,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M_Tuple,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple>
__global__ void
kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
const InElementwiseOperationTuple in_elementwise_op_tuple,
const AccElementwiseOperationTuple acc_elementwise_op_tuple,
Array<AccDataType, NumReduction> alpha_values,
const InDataType* const __restrict__ p_in_value_global,
Array<AccDataType, NumReduction> beta_values,
OutDataTypePointerTuple p_out_value_global_tuple)
{
GridwiseMultipleReduction::Run(in_grid_desc_m_k,
out_grid_desc_m_tuple,
in_elementwise_op_tuple,
acc_elementwise_op_tuple,
alpha_values,
p_in_value_global,
beta_values,
p_out_value_global_tuple);
};
template <index_t NumReduction,
typename InDataType,
typename OutDataTypePointerTuple,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M_Tuple,
typename ReduceOperation,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
typename OutDstVectorSizeSeq>
struct GridwiseMultipleReduction_mk_to_m_threadwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
NumReduction == OutGridDesc_M_Tuple::Size() &&
NumReduction == OutDstVectorSizeSeq::Size() &&
NumReduction == InElementwiseOperationTuple::Size() &&
NumReduction == AccElementwiseOperationTuple::Size(),
"All tuple should have the same size as the number of Reductions!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
const InElementwiseOperationTuple& in_elementwise_op_tuple,
const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
Array<AccDataType, NumReduction> alpha_values,
const InDataType* const __restrict__ p_in_value_global,
Array<AccDataType, NumReduction> beta_values,
OutDataTypePointerTuple p_out_value_global_tuple)
{
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto out_global_val_buf_tuple = generate_tuple(
[&](auto iR) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
},
Number<NumReduction>{});
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
auto in_thread_buf_tuple = generate_tuple(
[&](auto iR) {
(void)iR;
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
},
Number<NumReduction>{});
auto accu_value_buf_tuple = generate_tuple(
[&](auto iR) {
(void)iR;
return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>{};
},
Number<NumReduction>{});
static_for<0, NumReduction, 1>{}([&](auto iR) {
static_for<0, MThreadSliceSize, 1>{}(
[&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
});
const index_t thread_global_1d_id = get_thread_global_1d_id();
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t reducedLength = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, NumReduction, 1>{}([&](auto iR) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, NumReduction, 1>{}([&](auto iR) {
using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
using OutDataType = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
accu_value_buf_tuple(iR)(I));
accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
});
if(!float_equal_zero{}(beta_values[iR]))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
decltype(out_grid_desc_m_tuple[iR]),
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSizeSeq::At(iR),
1,
false>(
out_grid_desc_m_tuple[iR],
make_multi_index(thread_global_1d_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
out_global_val_buf_tuple(iR),
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf_tuple(iR)(I) +=
type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
});
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
decltype(out_grid_desc_m_tuple[iR]),
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSizeSeq::At(iR),
OutMemoryDataOperation,
1,
true>(
out_grid_desc_m_tuple[iR],
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf_tuple[iR],
out_grid_desc_m_tuple[iR],
out_global_val_buf_tuple(iR));
});
};
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename Gridwise5AryEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
Gridwise5AryEltwise::Run(p_a_global,
p_b_global,
p_c_global,
p_d_global,
p_e_global,
p_f_global,
a_grid_desc_m,
b_grid_desc_m,
c_grid_desc_m,
d_grid_desc_m,
e_grid_desc_m,
f_grid_desc_m,
functor);
}
// TODO - implement n-ary Elementwise_1D, tuple of inputs and tuple of outputs
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * MPerThread);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m.GetElementSpaceSize());
const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m.GetElementSpaceSize());
const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_grid_desc_m.GetElementSpaceSize());
const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_global, e_grid_desc_m.GetElementSpaceSize());
auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_f_global, f_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_load =
ThreadwiseTensorSliceTransfer_v2<CDataType,
ComputeDataType,
CGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
CScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{c_grid_desc_m, thread_store_global_offset};
auto d_global_load =
ThreadwiseTensorSliceTransfer_v2<DDataType,
ComputeDataType,
DGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
DScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{d_grid_desc_m, thread_store_global_offset};
auto e_global_load =
ThreadwiseTensorSliceTransfer_v2<EDataType,
ComputeDataType,
EGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
EScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{e_grid_desc_m, thread_store_global_offset};
auto f_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
FDataType,
decltype(thread_desc_m),
FGridDesc_M,
PassThrough,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
FScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
f_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = M / (loop_step);
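// Grid-stride style loop: each thread handles MPerThread contiguous elements per
// iteration and then advances by loop_step = blockPerGrid * blockSize * MPerThread;
// the do/while below appears to assume M is a non-zero multiple of loop_step.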
do
{
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
c_global_load.Run(
c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
d_global_load.Run(
d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
e_global_load.Run(
e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(f_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}),
c_thread_buf(Number<offset>{}),
d_thread_buf(Number<offset>{}),
e_thread_buf(Number<offset>{}));
});
f_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
f_thread_buf,
f_grid_desc_m,
f_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
......@@ -181,36 +181,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned =
math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(FloatAB);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatAB);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
......@@ -261,8 +241,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return false;
}
assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock);
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
......@@ -312,6 +290,36 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -358,9 +366,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
......@@ -464,14 +469,12 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......@@ -588,12 +591,20 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// reuse LDS space for gemm0's b_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following the same rationale as Gemm0KPack, Gemm1KPack would ideally be the lowest common
// multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. Here, however,
// Gemm1KPack cannot exceed A1K1 itself, because the A1 matrix is distributed in VGPRs with
// 'group_size' contiguous elements per group; a larger Gemm1KPack would mismatch the summation
// index, e.g. c[0:7] = a1[0:3, 8:11] * b1[0:7]. Therefore we may just as well assign
// Gemm1KPack = group_size.
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -611,10 +622,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
......@@ -699,6 +711,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same datatype
typename Acc0DataType,
typename D0sDataType,
typename Acc1DataType,
typename C1ShuffleDataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
InMemoryDataOperationEnum E1GlobalMemoryDataOperation,
typename A0GridDesc_M_K,
typename B0GridDesc_N_K,
typename D0sGridDesc_M_N,
typename B1GridDesc_N_K,
typename D1sGridDesc_M_N,
typename E1GridDesc_M_N,
index_t NumGemm0KPrefetchStage,
index_t BlockSize,
index_t Gemm0MPerBlock,
index_t Gemm0NPerBlock,
index_t Gemm0KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t A0K1Value,
index_t B0K1Value,
index_t B1K1Value,
index_t Gemm0MPerXdl,
index_t Gemm0NPerXdl,
index_t Gemm0MXdlPerWave,
index_t Gemm0NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
typename A0BlockTransferThreadClusterArrangeOrder,
typename A0BlockTransferSrcAccessOrder,
index_t A0BlockTransferSrcVectorDim,
index_t A0BlockTransferSrcScalarPerVector,
index_t A0BlockTransferDstScalarPerVector_AK1,
bool A0ThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t A0BlockLdsExtraM,
typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
index_t B0BlockTransferSrcVectorDim,
index_t B0BlockTransferSrcScalarPerVector,
index_t B0BlockTransferDstScalarPerVector_BK1,
bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t B0BlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t C1ShuffleGemm0MXdlPerWavePerShuffle,
index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto A0K1 = Number<A0K1Value>{};
static constexpr auto B0K1 = Number<B0K1Value>{};
static constexpr auto A0K0PerBlock = Number<Gemm0KPerBlock / A0K1Value>{};
static constexpr auto B0K0PerBlock = Number<Gemm0KPerBlock / B0K1Value>{};
static constexpr auto Gemm0MWaves = Gemm0MPerBlock / (Gemm0MPerXdl * Gemm0MXdlPerWave);
static constexpr auto Gemm0NWaves = Gemm0NPerBlock / (Gemm0NPerXdl * Gemm0NXdlPerWave);
// Gemm1
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto B1K0PerBlock = Number<Gemm1KPerBlock / B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static constexpr auto MakeD0sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return static_cast<const D0DataType*>(nullptr);
},
Number<NumD0Tensor>{});
}
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
static constexpr auto MakeD1sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
return static_cast<const D1DataType*>(nullptr);
},
Number<NumD1Tensor>{});
}
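// Note: MakeD0sGridPointer/MakeD1sGridPointer are only evaluated through decltype (see
// D0sGridPointer/D1sGridPointer further below); the typed null pointers merely pin down the
// ck::Tuple element types, and the actual device addresses are supplied by the caller of Run().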
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / Gemm0NPerXdl, Gemm0NPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
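// Note: GetGemm0WaveIdx() splits the flat thread id into (MWaveId, NWaveId, lane), and
// GetGemm0WaveMNIdx() splits the lane into (lane / Gemm0NPerXdl, lane % Gemm0NPerXdl). Together
// they locate a thread's accumulator fragment inside the C tile so the D0 reads in Run() can be
// addressed with the matching xdlops layout.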
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0MXdlPerWave, MWaves, Gemm0MPerXdl>(
A0BlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = Gemm0NPerBlock / (Gemm0NXdlPerWave * Gemm0NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0NXdlPerWave, NWaves, Gemm0NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0MXdlPerWave, 1, 1>(
A0BlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, Gemm0NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static constexpr auto GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A0 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(A0K0PerBlock, Number<Gemm0MPerBlock>{}, A0K1),
make_tuple(Number<Gemm0MPerBlock + A0BlockLdsExtraM>{} * A0K1, A0K1, I1));
}
__host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B0 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B0K0PerBlock, Number<Gemm0NPerBlock>{}, B0K1),
make_tuple(Number<Gemm0NPerBlock + B0BlockLdsExtraN>{} * B0K1, B0K1, I1));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0PerBlock, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl>{},
I1,
Number<C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>{}));
return c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a0_block_space_size_aligned +
SharedMemTrait::b0_block_space_size_aligned) *
sizeof(A0B0B1DataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(A0B0B1DataType);
const index_t c1_block_bytes_end =
SharedMemTrait::c1_block_space_size * sizeof(C1ShuffleDataType);
return math::max(gemm0_bytes_end, gemm1_bytes_end, c1_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2E1TileMap>
__host__ __device__ static constexpr bool
CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k,
const B0GridDesc_N_K& b0_grid_desc_n_k,
const B1GridDesc_N_K& b1_grid_desc_n_k,
const E1GridDesc_M_N& e1_grid_desc_m_n,
const Block2E1TileMap& block_2_e1tile_map)
{
static_assert((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) == 0) &&
(Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a0_grid_desc_m_k.GetLength(I0);
const auto N = b0_grid_desc_n_k.GetLength(I0);
const auto K = a0_grid_desc_m_k.GetLength(I1);
const auto Gemm1N = b1_grid_desc_n_k.GetLength(I0);
if(!(M == e1_grid_desc_m_n.GetLength(I0) && Gemm1N == e1_grid_desc_m_n.GetLength(I1)))
{
return false;
}
if(!(M % Gemm0MPerBlock == 0 && N % Gemm0NPerBlock == 0 && K % Gemm0KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / Gemm0KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(Gemm0NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
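// Note: CheckValidity() is presumably called on the host (e.g. from the owning device op's
// IsSupportedArgument) before launch; each early return above corresponds to one shape or
// tuning-parameter constraint of this fused kernel.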
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / Gemm0KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
// A0 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultA0GridDescriptor_AK0_M_AK1(const A0GridDesc_M_K& a0_grid_desc_m_k)
{
const auto M = a0_grid_desc_m_k.GetLength(I0);
const auto K = a0_grid_desc_m_k.GetLength(I1);
const auto A0K0 = K / A0K1;
return transform_tensor_descriptor(
a0_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A0K0, A0K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B0 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultB0GridDescriptor_BK0_N_BK1(const B0GridDesc_N_K& b0_grid_desc_n_k)
{
const auto N = b0_grid_desc_n_k.GetLength(I0);
const auto K = b0_grid_desc_n_k.GetLength(I1);
const auto B0K0 = K / B0K1;
return transform_tensor_descriptor(
b0_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B0K0, B0K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// D0 desc for source in blockwise copy
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto mfma =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(
M / Gemm0MPerBlock, Gemm0MXdlPerWave, Gemm0MWaves, Gemm0MPerXdl)),
make_unmerge_transform(make_tuple(N / Gemm0NPerBlock,
Gemm0NXdlPerWave,
Gemm0NWaves,
N3,
WaveSize / Gemm0NPerXdl,
N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
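// Note: the 10-dimensional view above splits D0's rows into (MBlock, MXdlPerWave, MWave,
// MPerXdl) and its columns into (NBlock, NXdlPerWave, NWave, num_groups_per_blk,
// WaveSize / NPerXdl, group_size), i.e. the same hierarchy the xdlops accumulator uses. This
// lets each thread fetch exactly the group_size-contiguous D0 elements that sit next to its
// acc0 registers.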
// B1 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k)
{
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// C1 desc for destination in blockwise copy
__host__ __device__ static constexpr auto
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc_M_N& e1_grid_desc_m_n)
{
const auto M = e1_grid_desc_m_n.GetLength(I0);
const auto N = e1_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / Gemm0MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e1_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<Gemm0MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e1_grid_desc_mblock_mperblock_nblock_nperblock;
}
// D0s desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]);
},
Number<NumD0Tensor>{});
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumD1Tensor>{});
}
// return block_id to C1 matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2E1TileMap(const E1GridDesc_M_N& e1_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<Gemm0MPerBlock, Gemm1NPerBlock, E1GridDesc_M_N>(
e1_grid_desc_m_n);
}
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(E1GridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(D1sGridDesc_M_N{}))>;
using DefaultBlock2E1TileMap =
remove_cvref_t<decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
struct SharedMemTrait
{
// LDS allocation for A0 and B0: be careful of alignment
static constexpr auto a0_block_desc_ak0_m_ak1 =
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b0_block_desc_bk0_n_bk1 =
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto max_lds_align = math::lcm(math::lcm(A0K1, B0K1), B1K1);
static constexpr auto a0_block_space_size_aligned = math::integer_least_multiple(
a0_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
b0_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a0_block_space_offset = 0;
static constexpr auto b0_block_space_offset = a0_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for C1 shuffle in LDS
static constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c1_block_space_size =
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D1sGridPointer = decltype(MakeD1sGridPointer());
template <bool HasMainKBlockLoop,
typename A0GridDesc_AK0_M_AK1,
typename B0GridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename Block2E1TileMap>
__device__ static void Run(const A0B0B1DataType* __restrict__ p_a0_grid,
const A0B0B1DataType* __restrict__ p_b0_grid,
D0sGridPointer p_d0s_grid,
const A0B0B1DataType* __restrict__ p_b1_grid,
D1sGridPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
void* __restrict__ p_shared,
const A0ElementwiseOperation& a0_element_op,
const B0ElementwiseOperation& b0_element_op,
const CDE0ElementwiseOperation& cde0_element_op,
const B1ElementwiseOperation& b1_element_op,
const CDE1ElementwiseOperation& cde1_element_op,
const A0GridDesc_AK0_M_AK1& a0_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_N_BK1& b0_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2E1TileMap& block_2_e1tile_map)
{
const auto a0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a0_grid, a0_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto e1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
const auto d1s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1s_grid[i],
d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumD1Tensor>{});
// divide block work by [M, N]
const auto block_work_idx =
block_2_e1tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_e1tile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * Gemm0MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// A0 matrix in LDS memory, dst of blockwise copy
constexpr auto a0_block_desc_ak0_m_ak1 = GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B0 matrix in LDS memory, dst of blockwise copy
constexpr auto b0_block_desc_bk0_n_bk1 = GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
//
// set up Gemm0
//
// A0 matrix blockwise copy
auto a0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
A0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<A0K0PerBlock, Gemm0MPerBlock, A0K1>,
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
A0BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(a0_grid_desc_ak0_m_ak1),
decltype(a0_block_desc_ak0_m_ak1),
A0BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
A0BlockTransferSrcVectorDim,
2,
A0BlockTransferSrcScalarPerVector,
A0BlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemm0KPrefetchStage>(
a0_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a0_element_op,
a0_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B0 matrix blockwise copy
auto b0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B0K0PerBlock, Gemm0NPerBlock, B0K1>,
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
B0BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(b0_grid_desc_bk0_n_bk1),
decltype(b0_block_desc_bk0_n_bk1),
B0BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B0BlockTransferSrcVectorDim,
2,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemm0KPrefetchStage>(
b0_grid_desc_bk0_n_bk1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b0_element_op,
b0_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// Fused Gemm+Gemm pipeline
// for n in N0:
// for k in K0:
// acc[m][n] += A[m][k] * B0[k][n]
// acc1[m][o] += acc[m][n] * B1[n][o]
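// In the code below: the outer do/while over gemm1_k_block_outer_index is the `n` loop (one
// Gemm0NPerBlock-wide slice of B0/B1 per iteration), gridwise_gemm0_pipeline.Run() is the inner
// `k` loop producing acc0 for that slice, and the "gemm1" section consumes acc0 as A1 against
// the freshly loaded B1 tile, accumulating into c1_thread_buf across all outer iterations.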
// sanity check
constexpr index_t KPack = math::max(
math::lcm(A0K1, B0K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
BlockSize,
A0B0B1DataType,
Acc0DataType,
decltype(a0_block_desc_ak0_m_ak1),
decltype(b0_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a0_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b0_block_desc_bk0_n_bk1)),
Gemm0MPerBlock,
Gemm0NPerBlock,
Gemm0KPerBlock,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm0NXdlPerWave,
KPack,
true>{}; // TransposeC
auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();
// LDS allocation for A0 and B0: be careful of alignment
auto a0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::a0_block_space_offset,
a0_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
b0_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / A0K1, 0, 0);
constexpr auto b0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / B0K1, 0, 0);
const auto a0_block_reset_copy_step =
make_multi_index(-a0_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
const auto b0_block_reset_copy_step =
make_multi_index(-b0_grid_desc_bk0_n_bk1.GetLength(I0), Gemm0NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm0_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemm0KPrefetchStage, LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a0_grid_desc_ak0_m_ak1.GetLength(I0) * a0_grid_desc_ak0_m_ak1.GetLength(I2)) /
Gemm0KPerBlock);
//
// set up Gemm1
//
// Acc0 matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm0.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// d0 matrix threadwise copy
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
I1, // MRepeat
I1, // NRepeat
I1, // MWaveId
I1, // NWaveId
I1, // MPerXdl
I1, // NGroupNum
I1, // NInputNum
n4)); // registerNum
auto d0s_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<
AddressSpaceEnum::Vgpr,
A0B0B1DataType,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>{};
},
Number<NumD0Tensor>{});
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4));
auto d0s_threadwise_copy = generate_tuple(
[&](auto i) {
return ThreadwiseTensorSliceTransfer_v2<
A0B0B1DataType,
A0B0B1DataType,
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
n4,
1,
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
},
Number<NumD0Tensor>{});
// acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc0_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: merge_v3 is required here; the regular merge transform triggers compilation errors
constexpr auto acc0_thread_desc_k0_m_k1 = transform_tensor_descriptor(
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr auto Acc0N3 =
blockwise_gemm0.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
Number<Gemm1KPerBlock / n4 / Acc0N3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
Acc0DataType,
A0B0B1DataType,
decltype(acc0_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0PerBlock, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(b1_grid_desc_bk0_n_bk1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
1>(b1_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, A0B0B1DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// reuse LDS space for gemm0's b0_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = math::max(
math::lcm(
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
B1K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize,
A0B0B1DataType,
Acc1DataType,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
Gemm0MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{}
.K0PerXdlops>{ // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
const index_t num_gemm1_k_block_outer_loop =
b0_grid_desc_bk0_n_bk1.GetLength(I1) / Gemm0NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock;
// Initialize C1
c1_thread_buf.Clear();
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
{
// gemm0
gridwise_gemm0_pipeline.template Run<HasMainKBlockLoop>(a0_grid_desc_ak0_m_ak1,
a0_block_desc_ak0_m_ak1,
a0_blockwise_copy,
a0_grid_buf,
a0_block_buf,
a0_block_slice_copy_step,
b0_grid_desc_bk0_n_bk1,
b0_block_desc_bk0_n_bk1,
b0_blockwise_copy,
b0_grid_buf,
b0_block_buf,
b0_block_slice_copy_step,
blockwise_gemm0,
acc0_thread_buf,
num_k_block_main_loop);
// bias+gelu
{
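// Traversal order: M xdl repeat (mr) -> N xdl repeat (nr) -> accumulator group (groupid),
// applying cde0_element_op over the n4 registers of each group. The MoveSrcSliceWindow calls
// with negative components rewind the faster-moving D0 indices whenever the next slower index
// advances; the final move (+1 on the N-block index) positions D0 for the next outer gemm1
// iteration while rewinding the M repeats.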
static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) {
static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) {
static_for<0, n2, 1>{}([&](auto groupid) {
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).Run(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
d0s_grid_buf[i],
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0s_thread_buf(i));
});
static_for<0, n4, 1>{}([&](auto i) {
constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
make_tuple(mr, nr, groupid, i));
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
return d0s_thread_buf[iSrc][i];
},
Number<NumD0Tensor>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& {
return acc0_thread_buf(Number<c_offset>{});
},
Number<2>{});
unpack2(cde0_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
});
}
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But that is impossible to implement here, because
// the A1 source buffer is a static buffer holding the output of the first GEMM and
// requires a constexpr offset by design. Therefore, we pass the tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for gemm0 LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc0_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf);
block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
});
}
// tail
{
a1_blockwise_copy.Run(
acc0_thread_desc_k0_m_k1,
make_tuple(
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf);
}
} // end gemm1
a0_blockwise_copy.MoveSrcSliceWindow(a0_grid_desc_ak0_m_ak1,
a0_block_reset_copy_step); // rewind K
b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_bk0_n_bk1,
b0_block_reset_copy_step); // rewind K and step N
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// shuffle C1 and write out
{
static_assert(Gemm0MXdlPerWave % C1ShuffleGemm0MXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % C1ShuffleGemm0NXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
// TODO: hacky, fix it!
constexpr auto c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm1.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm1.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c1_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<C1ShuffleDataType*>(p_shared),
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<C1ShuffleGemm0MXdlPerWavePerShuffle>{}, // M0 (Gemm0MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = Gemm0MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<C1ShuffleGemm0NXdlPerWavePerShuffle>{}, // N0 (Gemm0NXdlPerWave) per
// shuffle
N1, // N1 = NWave
N2))), // N2 = Gemm0NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM C1 matrix starting index
const auto c1_thread_mtx_on_block =
blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c1_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c1_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c1_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
C1ShuffleDataType,
decltype(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
tensor_operation::element_wise::PassThrough,
Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c1_d1s_desc_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumD1Tensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c1_d1s_buf_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_buf[i]; },
Number<NumD1Tensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c1_d1s_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumD1Tensor>{}));
// shuffle: blockwise copy C from LDS to global
auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(C1ShuffleDataType{}), D1sDataType{})),
Tuple<E1DataType>,
decltype(c1_d1s_desc_refs),
decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
CDE1ElementwiseOperation,
Sequence<static_cast<index_t>(E1GlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary
// type
Sequence<1,
C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
1,
C1ShuffleGemm0NXdlPerWavePerShuffle * NWave *
Gemm0NPerXdl>, // BlockSliceLengths,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumD1Tensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c1_d1s_desc_refs,
idx_c1_d1s_block_begin,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde1_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c1_vgpr =
SpaceFillingCurve<Sequence<Gemm0MXdlPerWave, Gemm1NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_e1_global = SpaceFillingCurve<
Sequence<1, Gemm0MPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
1,
C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>>{};
constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");
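// Note: sfc_c1_vgpr walks the accumulator in shuffle-sized chunks (one
// C1ShuffleGemm0{M,N}XdlPerWavePerShuffle slab at a time) while sfc_e1_global walks the
// corresponding sub-tiles of the MPerBlock x NPerBlock output tile, so every iteration below
// is exactly one VGPR -> LDS -> global round trip.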
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
c1_thread_buf,
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde1_shuffle_block_copy_lds_to_global.Run(
c1_d1s_desc_refs,
c1_d1s_buf_refs,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e1_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);
// move on D1s
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
c1_d1s_desc_refs, i + I1, e1_global_step);
});
// move on C
cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
}
});
}
}
};
} // namespace ck
......@@ -75,7 +75,8 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
......@@ -182,11 +183,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(FloatAB);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatAB);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
......@@ -237,8 +246,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
......@@ -302,25 +309,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
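// Note: all regions above are carved out of the single p_shared allocation. A sits at offset 0
// with B right after it for gemm0; b1_block_space_offset and reduction_space_offset both restart
// at the base, so gemm1's B1 tile and the softmax workspace reuse the LDS that gemm0 occupied.
// The block_sync_lds() calls in Run() serialize those reuses, and GetSharedMemoryNumberOfByte()
// above sizes the allocation as the max of the phase footprints.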
template <bool Pred>
struct ElementOpPredicatedResetNaNToMinusInf;
template <>
struct ElementOpPredicatedResetNaNToMinusInf<true>
{
template <typename ElementOp, typename OutT, typename InT>
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
{
if(ck::math::isnan(x))
{
y = -ck::NumericLimits<float>::Infinity();
}
else
{
op(y, x);
}
}
};
template <>
struct ElementOpPredicatedResetNaNToMinusInf<false>
{
template <typename ElementOp, typename OutT, typename InT>
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
{
op(y, x);
}
};
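// Note: when PadN is set, the padded A/B grid buffers in Run() below hand back QuietNaN for
// out-of-range elements, so accumulator entries touched by padding come out as NaN after gemm0.
// The <true> specialization resets those NaNs to -inf before the softmax, making padded columns
// contribute exp(-inf) = 0 to the row sums; the <false> specialization just forwards to the
// user's acc0 elementwise op.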
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -339,10 +379,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
const auto b_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -471,10 +521,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......@@ -549,11 +600,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatAB,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
......@@ -591,12 +642,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// reuse LDS space for gemm0's b_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following the same rationale as Gemm0KPack, Gemm1KPack would ideally be the lowest common
// multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. Here, however,
// Gemm1KPack cannot exceed A1K1 itself, because the A1 matrix is distributed in VGPRs with
// 'group_size' contiguous elements per group; a larger Gemm1KPack would mismatch the summation
// index, e.g. c[0:7] = a1[0:3, 8:11] * b1[0:7]. Therefore we may just as well assign
// Gemm1KPack = group_size.
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -617,7 +676,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
......@@ -625,10 +685,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// Blockwise softmax
//
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
SharedMemTrait::reduction_space_size_aligned);
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
......@@ -667,7 +725,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
,
true
#endif
>{};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
......@@ -706,6 +769,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
blockwise_gemm,
acc_thread_buf,
num_k_block_main_loop);
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
#endif
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
......@@ -717,7 +794,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
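// Streaming-softmax update across the outer N tiles: given the new running maximum
// running_max_new, the row sum is rescaled as
//   s_new = exp(m_old - m_new) * s_old + exp(m_tile - m_new) * s_tile
// so that sum_j exp(x_j - m) stays consistent as each tile of gemm0 output arrives, without
// ever materializing a full softmax row.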
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
......@@ -736,12 +812,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
......@@ -749,6 +826,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
......@@ -773,6 +851,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
......@@ -817,6 +896,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max = running_max_new;
running_sum = running_sum_new;
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// shuffle C and write out
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseBinEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const ElementwiseFunctor functor)
{
GridwiseBinEltwise::Run(
p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename ElementwiseFunctor,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * MPerThread);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
CDataType,
decltype(thread_desc_m),
CGridDesc_M,
PassThrough,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
CScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
c_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = M / (loop_step);
do
{
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}));
});
c_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
c_thread_buf,
c_grid_desc_m,
c_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
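// Usage sketch for the kernel above (illustrative; AddFunctor, the descriptor types and the
// launch parameters are hypothetical, not part of this file). Each thread handles MPerThread
// contiguous elements and the whole grid advances by grid_size * block_size * MPerThread per
// loop iteration, so M is expected to be a multiple of that step:
//
//   using Grid = ck::GridwiseBinaryElementwise_1D<float, float, float, float,
//                                                 ADesc, BDesc, CDesc, AddFunctor,
//                                                 /*MPerThread*/ 8, 4, 4, 4>;
//   const ck::index_t block_size = 256;
//   const ck::index_t grid_size  = M / (block_size * 8); // assumes M divides evenly
//   ck::kernel_binary_elementwise_1d<Grid, float, float, float, ADesc, BDesc, CDesc, AddFunctor>
//       <<<grid_size, block_size>>>(p_a, p_b, p_c, a_desc, b_desc, c_desc, AddFunctor{});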
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise1dFunctor,
typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op)
{
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op);
}
template <typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid1dDescTuple::Size() &&
NumOutput == OutGrid1dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op)
{
const index_t thread_global_id = get_thread_global_1d_id();
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
decltype(in_grid_1d_desc_tuple[I]),
decltype(thread_buffer_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(
I), // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc_tuple[I],
thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<DataType,
DataType,
decltype(thread_buffer_desc_m),
decltype(out_grid_1d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter = M / (loop_step);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_m,
make_tuple(I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
loop_step_index);
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
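// How the unpack2 call above dispatches (illustrative): for NumOutput == 1 and NumInput == 2
// it is equivalent to
//
//   elementwise_op(out_thread_buf_tuple(I0)(iM),
//                  in_thread_buf_tuple(Number<0>{})(iM),
//                  in_thread_buf_tuple(Number<1>{})(iM));
//
// i.e. the output references are passed first, followed by the input references.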
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_m,
make_tuple(I0),
out_thread_buf_tuple[I],
out_grid_1d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
loop_step_index);
});
} while(--num_iter);
}
};
} // namespace ck
......@@ -32,8 +32,8 @@ template <typename FloatAB,
typename ThreadReduceOperations,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename RsGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename EGridDesc_M_N,
typename RGridDesc_M,
index_t NumGemmKPrefetchStage,
......@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
......@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
......@@ -167,22 +167,57 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_block_size * sizeof(FloatCShuffle));
}
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
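// Worked example for the two helpers above (numbers are assumptions): for an A tensor viewed
// as (M, K) = (256, 64) with AK1 = 8, the unmerge splits K into (AK0, AK1) = (8, 8) and the
// transform reorders the dimensions to (AK0, M, AK1) = (8, 256, 8); no data moves, only the
// indexing changes. Note that the helpers assume K is divisible by AK1 / BK1.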
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const EGridDesc_M_N& e_grid_desc_m_n,
const RGridDesc_M& r_grid_desc_m,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(AGridDesc_M_K::GetNumOfDimension() == 2);
static_assert(BGridDesc_N_K::GetNumOfDimension() == 2);
static_assert(EGridDesc_M_N::GetNumOfDimension() == 2);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
return false;
......@@ -259,6 +294,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
e_grid_desc_m_n);
}
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
......@@ -272,7 +311,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using DsGridPointer = decltype(MakeTsGridPointer<DsDataType, true>());
using RsGridPointer = decltype(MakeTsGridPointer<RsDataType, false>());
template <bool HasMainKBlockLoop, typename Block2ETileMap>
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -356,7 +398,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
......@@ -387,7 +429,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
......@@ -776,8 +818,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to read from LDS
block_sync_lds();
// each thread shuffle data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
......
......@@ -53,7 +53,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
static_cast<void*>(p_shared_block),
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -270,7 +270,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared_block,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
......@@ -463,8 +463,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
......@@ -547,11 +547,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
static_cast<FloatC*>(p_shared_block),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
static_assert(N1 == NWave, "");
static_assert(M2 * M3 * M4 == MPerXDL, "");
static_assert(N2 == NPerXDL, "");
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename Grid1dBufferDescTuple,
index_t NumBuffer,
index_t BlockSize,
typename DataTypePointerTuple,
typename DataTypeTuple>
__global__ void
kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple,
DataTypePointerTuple p_global_tuple,
DataTypeTuple value_tuple)
{
static_assert(NumBuffer == DataTypePointerTuple::Size() && NumBuffer == DataTypeTuple::Size(),
"The tuple size should be same as NumBuffer!");
static_for<0, NumBuffer, 1>{}([&](auto iB) {
using DataTypePointer = remove_cvref_t<decltype(DataTypePointerTuple{}[iB])>;
using DataTypeFromPointer = remove_pointer_t<DataTypePointer>;
using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;
static_assert(is_same<DataType, DataTypeFromPointer>::value,
"Types in tuples does not match!");
});
constexpr auto I0 = Number<0>{};
const index_t thread_global_id = get_thread_global_1d_id();
auto value_buf_tuple = generate_tuple(
[&](auto iB) {
using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, 1, true>{};
},
Number<NumBuffer>{});
static_for<0, NumBuffer, 1>{}([&](auto iB) {
static_for<0, 1, 1>{}([&](auto J) { value_buf_tuple(iB)(J) = value_tuple[iB]; });
});
auto global_buf_tuple = generate_tuple(
[&](auto iB) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_global_tuple(iB), grid_1d_buffer_desc_tuple[iB].GetElementSpaceSize());
},
Number<NumBuffer>{});
constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
static_for<0, NumBuffer, 1>{}([&](auto iB) {
using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
auto threadwise_store =
ThreadwiseTensorSliceTransfer_v1r3<DataType,
DataType,
decltype(val_buff_desc),
decltype(Grid1dBufferDescTuple{}[iB]),
PassThroughOp,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
grid_1d_buffer_desc_tuple[iB], make_multi_index(thread_global_id), PassThroughOp{});
threadwise_store.Run(val_buff_desc,
make_tuple(I0),
value_buf_tuple(iB),
grid_1d_buffer_desc_tuple[iB],
global_buf_tuple(iB));
});
};
} // namespace ck
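// Behavior sketch (illustrative): each thread writes exactly one element, so a launch that
// covers every element initializes buffer iB to value_tuple[iB], roughly the parallel form of
//
//   for(ck::index_t e = 0; e < grid_1d_buffer_desc_tuple[iB].GetLength(I0); ++e)
//       p_global_tuple[iB][e] = value_tuple[iB];
//
// which is useful for pre-setting reduction or workspace buffers before a kernel that
// accumulates into them.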
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
namespace ck {
template <typename GridwiseSparseEmbedding,
typename EmbType,
typename IndexType,
typename GammaDataType,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename OutGridDesc>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
__global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon)
{
GridwiseSparseEmbedding::Run(p_out,
p_emb_a,
p_emb_b,
p_emb_c,
p_index_a,
p_index_b,
p_index_c,
p_gamma,
p_beta,
out_grid_desc,
epsilon);
}
template <typename EmbType,
typename IndexType,
typename GammaDataType,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename OutGridDesc,
ck::index_t BlockSize,
ck::index_t DimClusterSize,
ck::index_t RowClusterSize,
ck::index_t DimPerBlock, // Row x Dim, along Dim
ck::index_t RowPerBlock, // Row x Dim, along Row
ck::index_t DimThreadSize, // this is not a vector size, but the number of registers per thread
ck::index_t RowVectorSize>
struct GridwiseSparseEmbedding3ForwardLayernorm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64;
static_assert(BlockSize == RowClusterSize * DimClusterSize,
"Invalid cluster distribution within block");
static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise");
static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, "");
static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, "");
static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);
static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), "");
static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks;
static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks;
using ThreadwiseWelfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
using ThreadwiseWelfordDescReduce = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadwiseWelfordDesc2D, ThreadwiseWelfordDescReduce>;
using ThreadClusterLength = Sequence<DimClusterSize, RowClusterSize>;
using BlockwiseWelford =
BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
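// Reference form of the Welford update these helpers implement (scalar sketch with
// hypothetical names; not the library implementation):
//
//   inline void WelfordUpdate(float& mean, float& m2, int& count, float x)
//   {
//       count += 1;
//       const float delta = x - mean;
//       mean += delta / count;
//       m2   += delta * (x - mean); // variance = m2 / count after the final update
//   }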
__device__ static void Run(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc,
const AccDataType epsilon)
{
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
// const auto index_length = out_grid_desc.GetLength(I0);
// const auto emb_dim = out_grid_desc.GetLength(I1);
constexpr auto thread_cluster_desc =
make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_dim_cluster_id = thread_cluster_idx[I0];
const auto thread_row_cluster_id = thread_cluster_idx[I1];
const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize);
const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize;
constexpr auto thread_buf_size =
DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize;
constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed(
make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize));
constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize;
constexpr auto mean_var_buf_desc =
make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize));
constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize;
constexpr auto gamma_beta_buf_desc =
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
using src_vector_t = typename decltype(emb_vector_a)::type;
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
IndexType index_a = index_buf_a[Number<current_dim>{}];
IndexType index_b = index_buf_b[Number<current_dim>{}];
IndexType index_c = index_buf_c[Number<current_dim>{}];
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(EmbType) * RowVectorSize;
int32x4_t emb_res_a =
make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
int32x4_t emb_res_b =
make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
int32x4_t emb_res_c =
make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
emb_vector_a.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
emb_vector_b.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
emb_vector_c.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
in_thread_buf_a(Number<register_offset>{}) =
emb_vector_a.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_b(Number<register_offset>{}) =
emb_vector_b.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_c(Number<register_offset>{}) =
emb_vector_c.template AsType<EmbType>()[i_row_vec_];
});
});
};
auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
AccDataType va =
ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
AccDataType vb =
ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
AccDataType vc =
ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
});
});
};
auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
constexpr auto mean_var_offset =
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
threadwise_welford.cur_count_++;
threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
var_thread_buf(Number<mean_var_offset>{}),
acc_thread_buf(Number<register_offset>{}));
});
});
};
auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) {
int32x4_t out_res =
make_wave_buffer_resource_with_default_range(p_out + index_start * RowPerBlock);
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
vector_type_maker_t<OutType, RowVectorSize> out_vector;
using dst_vector_t = typename decltype(out_vector)::type;
constexpr auto mean_var_offset =
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
constexpr auto gamma_beta_offset =
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
auto acc_val = acc_thread_buf[Number<register_offset>{}];
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
beta_thread_buf[Number<gamma_beta_offset>{}];
out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
type_convert<OutType>(acc_val);
});
index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(OutType) * RowVectorSize;
amd_buffer_store_impl<OutType, RowVectorSize>(
out_vector.template AsType<dst_vector_t>()[Number<0>{}],
out_res,
thread_offset,
0);
});
};
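// Scalar reference for the per-element normalization performed above (illustrative only):
//
//   inline float LayerNormElement(float x, float mean, float var, float eps,
//                                 float gamma, float beta)
//   {
//       return (x - mean) / sqrtf(var + eps) * gamma + beta;
//   }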
// first load index
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
// prefer to use s_load
index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
});
// load gamma/beta
static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) {
vector_type_maker_t<GammaDataType, RowVectorSize> gamma_vector;
vector_type_maker_t<BetaDataType, RowVectorSize> beta_vector;
index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(GammaDataType) * RowVectorSize;
index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(BetaDataType) * RowVectorSize;
int32x4_t gamma_res = make_wave_buffer_resource_with_default_range(p_gamma);
int32x4_t beta_res = make_wave_buffer_resource_with_default_range(p_beta);
gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(I0) =
amd_buffer_load_impl<GammaDataType, RowVectorSize>(
gamma_res, thread_offset_gamma, 0);
beta_vector.template AsType<typename decltype(beta_vector)::type>()(I0) =
amd_buffer_load_impl<BetaDataType, RowVectorSize>(beta_res, thread_offset_beta, 0);
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto offset =
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
gamma_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
gamma_vector.template AsType<GammaDataType>()[Number<i_row_vec_>{}]);
beta_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
beta_vector.template AsType<BetaDataType>()[Number<i_row_vec_>{}]);
});
});
static_for<0, thread_buf_size, 1>{}(
[&](auto I) { acc_thread_buf(I) = type_convert<AccDataType>(0.0f); });
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) {
load_current_sub_row(i_dim_sub, Number<0>{});
static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) {
load_current_sub_row(i_dim_sub, Number<1>{} + i_row);
accumulate_current_sub_row(i_dim_sub, i_row);
threadwise_welford_sub_row(i_dim_sub, i_row);
});
accumulate_current_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
threadwise_welford_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
// blockwise welford
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
});
// store
static_for<0, RowSubBlocks, 1>{}(
[&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
});
}
};
} // namespace ck
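// Host-side summary of what GridwiseSparseEmbedding3ForwardLayernorm computes, as read from
// the code above (illustrative pseudocode with hypothetical names):
//
//   for(int i = 0; i < num_indices; ++i)
//   {
//       for(int r = 0; r < row_length; ++r)
//           acc[r] = emb_a[idx_a[i]][r] + emb_b[idx_b[i]][r] + emb_c[idx_c[i]][r];
//       auto [mean, var] = welford(acc, row_length);
//       for(int r = 0; r < row_length; ++r)
//           out[i][r] = (acc[r] - mean) / sqrt(var + eps) * gamma[r] + beta[r];
//   }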
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseUEltwise,
typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor>
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
}
template <typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor,
index_t ScalarPerVector>
struct GridwiseUnaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * ScalarPerVector);
}
__host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0)
{
return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
{
const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
return grid_size;
}
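// Example of the grid-size formula above (values are assumptions): for tensor_size = 4096 and
// ScalarPerVector = 8, integer_divide_ceil(4096, 256 * 8) = 2 blocks; the 256 here is the
// block size this helper assumes.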
__device__ static void Run(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_store_global_offset};
auto b_global_write =
ThreadwiseTensorSliceTransfer_v1r3<BDataType,
BDataType,
decltype(thread_desc_m0),
GridDesc_M0,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
b_grid_desc_m0, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto m0 = b_grid_desc_m0.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = m0 / (loop_step);
do
{
// read and process ScalarPerVector elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
});
b_global_write.Run(thread_desc_m0,
make_tuple(I0), // SrcSliceOriginIdx
b_thread_buf,
b_grid_desc_m0,
b_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck