Commit e9047ab9 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents bc641634 a2969aa8
@@ -9,8 +9,77 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDTensor>
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
struct ComputePtrOffsetOfStridedBatch
{
};
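// The primary template is intentionally empty; the partial specializations below cover
// grouped multi-A/B GEMMs (NumATensor > 1 or NumBTensor > 1) and the common
// single-A/single-B case (NumATensor == 1 and NumBTensor == 1).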
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor,
NumDTensor,
ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
Array<ck::index_t, NumBTensor>& BatchStrideBs,
Array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(BatchStrideAs),
BatchStrideB_(BatchStrideBs),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
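// Per-batch offsets are accumulated in long_index_t so that large batched buffers do not
// overflow 32-bit index arithmetic.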
__host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumATensor> as_offset;
static_for<0, NumATensor, 1>{}(
[&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
return as_offset;
}
__host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumBTensor> bs_offset;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
return bs_offset;
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
return ds_offset;
}
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
Array<ck::index_t, NumATensor> BatchStrideA_;
Array<ck::index_t, NumBTensor> BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor,
NumDTensor,
ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
{
ComputePtrOffsetOfStridedBatch() = default;
@@ -54,13 +123,67 @@ struct ComputePtrOffsetOfStridedBatch
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
ck::index_t BatchStrideA_;
ck::index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
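// The helpers below let one kernel template serve both the multi-tensor (tuple) and the
// single-tensor A/B paths.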
template <bool isTuple, typename Tensors>
constexpr static auto GetNumABTensors()
{
if constexpr(isTuple)
{
return Number<Tensors::Size()>{};
}
else
{
return Number<1>{};
}
}
template <bool isTuple, typename GridwiseGemm, typename DataType>
constexpr static auto GetAGridPointer()
{
if constexpr(isTuple)
{
return typename GridwiseGemm::AsGridPointer{};
}
else
{
return Tuple<const DataType*>{};
}
}
template <bool isTuple, typename GridwiseGemm, typename DataType>
constexpr static auto GetBGridPointer()
{
if constexpr(isTuple)
{
return typename GridwiseGemm::BsGridPointer{};
}
else
{
return Tuple<const DataType*>{};
}
}
template <bool isTuple, typename Id, typename Type>
constexpr static auto UnpackDataType()
{
if constexpr(isTuple)
{
// unpack if tuple
return tuple_element_t<Id{}, Type>{};
}
else
{
// if not a tuple, return the Type itself
return Type{};
}
}
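// Usage sketch (illustrative, not part of this change):
//   GetNumABTensors<true, Tuple<half_t, float>>()  -> Number<2>{}
//   GetNumABTensors<false, half_t>()               -> Number<1>{}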
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -142,19 +142,18 @@ struct DeviceImageToColumnImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
OutputGridDesc{}))>;
using GridwiseTensorRearrangeKernel =
GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>>;
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<>>;
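// With the defaulted template parameters above, ComputePtrOffsetOfStridedBatch<> selects the
// single-A/single-B specialization (NumATensor = 1, NumBTensor = 1, NumDTensor = 0).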
struct Argument : public BaseArgument
{
@@ -224,7 +223,7 @@ struct DeviceImageToColumnImpl
InputGridDesc in_grid_desc_m_k_;
OutputGridDesc out_grid_desc_m_k_;
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
};
struct Invoker : public BaseInvoker
@@ -246,7 +245,7 @@ struct DeviceImageToColumnImpl
OutputGridDesc,
OutputDataType,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>,
ComputePtrOffsetOfStridedBatch<>,
GridwiseTensorRearrangeKernel>;
float elapsed_time = launch_and_time_kernel(stream_config,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is the invariant dimension, K is the reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseReduction,
typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
typename GridDesc_M_K,
typename GridDesc_M>
__global__ void
kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k,
const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K mean_grid_desc_m_k,
const GridDesc_M_K inv_std_grid_desc_m_k,
const GridDesc_M dgamma_grid_desc_m,
const GridDesc_M dbeta_grid_desc_m,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DGammaDataType* const __restrict__ p_dgamma_global,
DBetaDataType* const __restrict__ p_dbeta_global)
{
GridwiseReduction::Run(dy_grid_desc_m_k,
x_grid_desc_m_k,
mean_grid_desc_m_k,
inv_std_grid_desc_m_k,
dgamma_grid_desc_m,
dbeta_grid_desc_m,
num_k_block_tile_iteration,
p_dy_global,
p_x_global,
p_mean_global,
p_inv_std_global,
p_dgamma_global,
p_dbeta_global);
};
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DGammaDataType,
typename DBetaDataType,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
bool IsDYFastestDimReduced,
index_t DYSrcVectorSize,
bool IsXFastestDimReduced,
index_t XSrcVectorSize,
bool IsMeanInvStdFastestDimReduced,
index_t MeanInvStdSrcVectorSize,
index_t DGammaDstVectorSize,
index_t DBetaDstVectorSize>
struct DeviceNormalizationBwdGammaBetaImpl
: public DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
Rank,
NumReduceDim>
{
static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
(DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static_assert(
(MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim);
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int numBlockTileIteration)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<Rank>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_grid_desc_m_k_padded;
}
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
{
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumInvariantDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumInvariantDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
using GridDesc_M = decltype(MakeDst1dDescriptor({1}, {1}));
using GridwiseNormalizationBwdGammaBeta =
GridwiseNormalizationBwdGammaBeta_mk_to_k<DYDataType,
XDataType,
MeanInvStdDataType,
ComputeDataType,
DGammaDataType,
DBetaDataType,
GridDesc_M_K,
GridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
DYSrcVectorDim,
DYSrcVectorSize,
XSrcVectorDim,
XSrcVectorSize,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
DGammaDstVectorSize,
DBetaDstVectorSize>;
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> dgammaStrides,
const std::vector<index_t> dbetaStrides,
const std::vector<index_t> reduceDims,
const DYDataType* p_dy,
const XDataType* p_x,
const MeanInvStdDataType* p_mean,
const MeanInvStdDataType* p_invStd,
DGammaDataType* p_dgamma,
DBetaDataType* p_dbeta)
: p_dy_(p_dy),
p_x_(p_x),
p_mean_(p_mean),
p_invStd_(p_invStd),
p_dgamma_(p_dgamma),
p_dbeta_(p_dbeta),
outLengths_{outLengths},
dgammaStrides_{dgammaStrides},
dbetaStrides_{dbetaStrides}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
dyStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
invStdStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(inLengths_);
numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
dy_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, dyStrides_, numBlockTileIteration_);
x_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, xStrides_, numBlockTileIteration_);
mean_grid_desc_m_k_ =
MakeSrc2dDescriptor(inLengths_, meanStrides_, numBlockTileIteration_);
inv_std_grid_desc_m_k_ =
MakeSrc2dDescriptor(inLengths_, invStdStrides_, numBlockTileIteration_);
dgamma_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dgammaStrides_);
dbeta_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dbetaStrides_);
}
const DYDataType* p_dy_;
const XDataType* p_x_;
const MeanInvStdDataType* p_mean_;
const MeanInvStdDataType* p_invStd_;
DGammaDataType* p_dgamma_;
DBetaDataType* p_dbeta_;
std::vector<index_t> inLengths_;
std::vector<index_t> dyStrides_;
std::vector<index_t> xStrides_;
std::vector<index_t> meanStrides_;
std::vector<index_t> invStdStrides_;
std::vector<index_t> outLengths_;
std::vector<index_t> dgammaStrides_;
std::vector<index_t> dbetaStrides_;
int numBlockTileIteration_;
size_t gridSize_;
// Source descriptor
GridDesc_M_K dy_grid_desc_m_k_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K mean_grid_desc_m_k_;
GridDesc_M_K inv_std_grid_desc_m_k_;
// Destination descriptor
GridDesc_M dgamma_grid_desc_m_;
GridDesc_M dbeta_grid_desc_m_;
index_t MRaw_; // invariant length
index_t KRaw_; // reduced length
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel_main =
kernel_normalization_bwd_gamma_beta<GridwiseNormalizationBwdGammaBeta,
DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
GridDesc_M_K,
GridDesc_M>;
return launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.dy_grid_desc_m_k_,
arg.x_grid_desc_m_k_,
arg.mean_grid_desc_m_k_,
arg.inv_std_grid_desc_m_k_,
arg.dgamma_grid_desc_m_,
arg.dbeta_grid_desc_m_,
arg.numBlockTileIteration_,
arg.p_dy_,
arg.p_x_,
arg.p_mean_,
arg.p_invStd_,
arg.p_dgamma_,
arg.p_dbeta_);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
template <index_t SrcVectorDim, index_t SrcVectorSize>
bool IsSrcVectorDimSizeValid(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
if constexpr(SrcVectorSize == 1)
return true;
// Fastest dimension is not reduced
if constexpr(SrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
return false;
if(strides[NumInvariantDim - 1] != 1)
return false;
if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
return false;
}
else // Fastest dimension is reduced
{
if(strides[Rank - 1] != 1)
return false;
if(lengths[Rank - 1] % SrcVectorSize != 0)
return false;
};
return true;
}
template <index_t DstVectorSize>
bool IsDstVectorSizeValid(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
if constexpr(DstVectorSize == 1)
return true;
if(strides[NumInvariantDim - 1] != 1)
return false;
if(lengths[NumInvariantDim - 1] % DstVectorSize != 0)
return false;
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
bool pass = true;
pass &= IsSrcVectorDimSizeValid<DYSrcVectorDim, DYSrcVectorSize>(p_arg_->inLengths_,
p_arg_->dyStrides_);
pass &= IsSrcVectorDimSizeValid<XSrcVectorDim, XSrcVectorSize>(p_arg_->inLengths_,
p_arg_->xStrides_);
pass &= IsSrcVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->inLengths_, p_arg_->meanStrides_);
pass &= IsSrcVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->inLengths_, p_arg_->invStdStrides_);
pass &=
IsDstVectorSizeValid<DGammaDstVectorSize>(p_arg_->outLengths_, p_arg_->dgammaStrides_);
pass &=
IsDstVectorSizeValid<DBetaDstVectorSize>(p_arg_->outLengths_, p_arg_->dbetaStrides_);
return pass;
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> dgammaStrides,
const std::vector<index_t> dbetaStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_mean,
const void* p_invStd,
void* p_dgamma,
void* p_dbeta) override
{
if(inLengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
meanStrides.size() != Rank || invStdStrides.size() != Rank)
throw std::runtime_error("dimension is incorrect");
if(outLengths.size() != NumInvariantDim || dgammaStrides.size() != NumInvariantDim ||
dbetaStrides.size() != NumInvariantDim)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(inLengths,
dyStrides,
xStrides,
meanStrides,
invStdStrides,
outLengths,
dgammaStrides,
dbetaStrides,
reduceDims,
static_cast<const DYDataType*>(p_dy),
static_cast<const XDataType*>(p_x),
static_cast<const MeanInvStdDataType*>(p_mean),
static_cast<const MeanInvStdDataType*>(p_invStd),
static_cast<DGammaDataType*>(p_dgamma),
static_cast<DBetaDataType*>(p_dbeta));
}
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -85,10 +85,13 @@ struct Add
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
}
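// Worked example (illustrative, not part of this change):
//   ScaleAdd op{2.f};
//   float y = 0.f;
//   op(y, 3.f, 1.f); // y = 2 * 3 + 1 = 7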
template <>
__host__ __device__ void
......
@@ -281,6 +281,24 @@ struct ConvertF8SR
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
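// Usage sketch (illustrative, not part of this change):
//   f8_t y;
//   ConvertF8RNE{}(y, 1.25f); // y holds 1.25f rounded to the nearest-even representable f8 value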
struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
@@ -355,8 +373,8 @@ struct UnarySquare
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
is_same_v<T, int8_t>
static_assert(is_same_v<T, float> || is_same_v<T, half_t> || is_same_v<T, double> ||
is_same_v<T, int32_t> || is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| is_same_v<T, int4_t>
#endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise1dFunctor,
typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
unary_op,
scale_op);
}
template <typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid1dDescTuple::Size() &&
NumOutput == OutGrid1dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
const index_t thread_global_id = get_thread_global_1d_id();
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
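// Grid-stride loop: each thread owns MPerThread contiguous elements and the whole grid
// advances by loop_step elements per iteration.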
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
decltype(in_grid_1d_desc_tuple[I]),
decltype(thread_buffer_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(
I), // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc_tuple[I],
thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<DataType,
DataType,
decltype(thread_buffer_desc_m),
decltype(out_grid_1d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter = M / (loop_step);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_m,
make_tuple(I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
loop_step_index);
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
// get reference to in data
auto uop_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
Number<NumOutput>{});
unpack2(unary_op, uop_data_refs, uop_data_refs);
auto sop_in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
auto sop_out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_m,
make_tuple(I0),
out_thread_buf_tuple[I],
out_grid_1d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
loop_step_index);
});
} while(--num_iter);
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise3dFunctor,
typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_3d(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
GridwiseElementwise3dFunctor::Run(in_grid_3d_desc_tuple,
out_grid_3d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
num_threads_m,
num_threads_n,
num_threads_k);
}
template <typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_3D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid3dDescTuple::Size() &&
NumOutput == OutGrid3dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed(
make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto M = in_grid_3d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_3d_desc_tuple[I0].GetLength(I1);
const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
const index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k);
const index_t tid_nk = thread_1d_id % (num_threads_n * num_threads_k);
const index_t tid_n = tid_nk / num_threads_k;
const index_t tid_k = tid_nk % num_threads_k;
const auto thread_global_offset =
make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(in_grid_3d_desc_tuple[I]),
decltype(thread_buffer_desc_mnk),
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
1, // SrcVectorDim
InScalarPerVectorSeq::At(I), // ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_3d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<
DataType,
DataType,
decltype(thread_buffer_desc_mnk),
decltype(out_grid_3d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
OutScalarPerVectorSeq::At(I), // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_3d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter_m = M / (loop_step_m);
do
{
index_t num_iter_n = N / (loop_step_n);
do
{
index_t num_iter_k = K / (loop_step_k);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
static_for<0, KPerThread, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
});
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
out_thread_buf_tuple[I],
out_grid_3d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
} while(--num_iter_k);
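// The K sweep is done: rewind the K window and step the windows forward by one N tile.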
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_n);
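// The N sweep is done: rewind both the N and K windows and step forward by one M tile.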
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_m);
}
};
} // namespace ck
@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
@@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename AsGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
{
return generate_tuple(
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
[&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
Number<NumATensor>{});
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
@@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
{
return generate_tuple(
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
[&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
Number<NumBTensor>{});
}
@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
@@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
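// The kernel body is compiled only for the host pass and for gfx90a/gfx94x device targets,
// where the LDS direct-load path is available; on other targets the arguments are consumed
// by `ignore` so the translation unit still compiles.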
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
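// Illustrative example (not from this change): with a single D tensor and an addition-style
// CDEElementwiseOperation, the epilogue computes E[m, n] = cde_op(C[m, n], D0[m, n])
// = C[m, n] + D0[m, n], where C = a_op(A) * b_op(B).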
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferScalarPerVector,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferScalarPerVector,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v4,
typename BComputeDataType = AComputeDataType_>
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else
using AComputeDataType = AComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle.
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
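// The main GEMM loop needs the A/B tiles in LDS while the C-shuffle epilogue needs the
// shuffle buffer; the two phases do not overlap, so a single allocation sized to the maximum
// of the two is shared.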
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
b_block_space_size_aligned * sizeof(BComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
__host__ __device__ static auto
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); },
Number<NumDTensor>{});
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// A desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy.
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using Block2ETileMap = remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
static_assert(
std::is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough> &&
std::is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>,
"Direct load transfers do not support elementwise operations other than passthrough.");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// Check the consistency of descriptors.
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// Check the tile size.
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// Check gridwise gemm pipeline.
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// Check block-to-E-tile.
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// Check tensor size: cannot exceed 2GB.
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DsGridPointer = decltype(MakeDsGridPointer());
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
// Elementwise operations are not supported for A and B; the arguments are kept only for API
// consistency.
(void)a_element_op;
(void)b_element_op;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// Divide block work by [M, N].
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// This forces m/n_block_data_idx_on_grid into SGPR.
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, destination of blockwise copy.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, destination of blockwise copy.
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ADataType,
AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcVectorDim,
2,
ABlockTransferScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BDataType,
BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcVectorDim,
2,
BBlockTransferScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// Shuffle C and write out.
{
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "MXdlPerWave and NXdlPerWave must be divisible by "
                          "CShuffleMXdlPerWavePerShuffle and CShuffleNXdlPerWavePerShuffle!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// Calculate the origin of thread output tensor on global memory.
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Shuffle: threadwise copy C from VGPR to LDS.
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
            // A tuple of references to the C/Ds tensor descriptors.
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
            // A tuple of references to the C/Ds grid buffers.
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
            // A tuple of the starting indices of the C/Ds blockwise copies.
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// Blockwise copy C/D/E between LDS and global.
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// Space filling curve for threadwise C in VGPR before shuffle.
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// Space filling curve for shuffled blockwise C/D/E.
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
                          "The VGPR and block space-filling curves must have the same number of "
                          "accesses!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// Make sure it's safe to write to LDS.
block_sync_lds();
                // Each thread writes its data from VGPR to LDS.
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// Make sure it's safe to read from LDS.
block_sync_lds();
                // Each block copies its data from LDS to global memory.
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
                    // Move the source slice windows over the D tensors.
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
                // Move the destination slice window over E.
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
struct Argument : public tensor_operation::device::BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
a_grid_desc_ak0_m_ak1_{MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw}
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
});
if(CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// Pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
        // Tensor descriptors for the problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// Tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise ops
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// For checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
};
};
} // namespace ck
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp"
namespace ck {
......@@ -14,6 +15,8 @@ enum struct PipelineVersion
{
v1,
v2,
// v3 is only used in the Stream-K implementation.
v4,
};
template <PipelineVersion PipelineVer,
......@@ -36,6 +39,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
return GridwiseGemmPipeline_v2{};
}
else if constexpr(PipelineVer == PipelineVersion::v4)
{
return GridwiseGemmPipeline_v4<NumPrefetch>{};
}
else
{
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <index_t NumPrefetch>
struct GridwiseGemmPipeline_v4;
// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v4<1>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
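    // Prefetch schedule (single stage): one K tile of A and B is issued before the loop, the
    // main loop overlaps the direct loads of the next tile with the blockwise GEMM on the
    // current tile, and the tail consumes the last tile. CalculateHasMainLoop() is therefore
    // true only when more than one K tile has to be processed.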
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds_direct_load();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds_direct_load();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds_direct_load();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
......@@ -996,6 +996,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}
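    // Without a K-padding GemmSpecialization the K dimension is used as given, so K0 must be
    // an integer multiple of K0PerBlock; otherwise the final K tile would be partial.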
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.K0 % K0PerBlock == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace ck {
// dgamma = reduce_sum(dy * (x - mean) * inv_std)
// dbeta = reduce_sum(dy)
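// Each workgroup owns an M_BlockTileSize slice of M and reduces over the K (inner) dimension
// of the [M, K] view: every thread accumulates its K slice into per-M partial sums, the
// partials are combined with a blockwise reduction across the K thread cluster, and the
// threads with thread_k_cluster_id == 0 write the final dgamma/dbeta values.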
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DGammaDataType,
typename DBetaDataType,
typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t DYSrcVectorDim,
index_t DYSrcVectorSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t MeanInvStdSrcVectorDim,
index_t MeanInvStdSrcVectorSize,
index_t DGammaDstVectorSize,
index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{
    // Requiring only that ThreadSliceSize % VectorSize == 0 could hurt performance,
    // so the slice size along the vector dimension must equal the vector size.
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder =
typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using XThreadBufferDimAccessOrder =
typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using MeanInvStdThreadBufferDimAccessOrder =
typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
true>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& mean_grid_desc_m_k,
const GridDesc_M_K& inv_std_grid_desc_m_k,
const GridDesc_M& dgamma_grid_desc_m,
const GridDesc_M& dbeta_grid_desc_m,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DGammaDataType* const __restrict__ p_dgamma_global,
DBetaDataType* const __restrict__ p_dbeta_global)
{
// LDS
__shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
// Global
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
auto dgamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize());
auto dbeta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize());
// VGPR
auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto dgamma_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
auto dbeta_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
// IO
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
DYThreadBufferDimAccessOrder,
DYSrcVectorDim,
DYSrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
XThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_mean_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
true>(
mean_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_inv_std_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
true>(
inv_std_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dgamma_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DGammaDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DGammaDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dgamma_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbeta_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DBetaDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DBetaDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dbeta_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dgamma_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
dbeta_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_mean_load.Run(mean_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
mean_thread_buf);
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
inv_std_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
inv_std_thread_buf);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
thread_copy_fwd_step_m_k);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
dgamma_thread_buf(offset_m) +=
dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
(x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]);
dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k];
});
});
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I));
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_dgamma_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dgamma_thread_buf,
dgamma_grid_desc_m,
dgamma_global_val_buf);
threadwise_dbeta_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbeta_thread_buf,
dbeta_grid_desc_m,
dbeta_global_val_buf);
}
}
};
} // namespace ck
......@@ -944,4 +944,41 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
#endif
}
// Direct loads from global to LDS.
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
T* lds_base_ptr,
const index_t lds_offset,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
    // 0x80000000 is an intentionally out-of-range byte offset used to mark invalid lanes.
    const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
}
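// For example (illustration only): with T = float, NumElemsPerThread must be 1
// (4 bytes == one DWORD), and with T = half_t it must be 2; any other combination is
// rejected by the static_assert above.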
} // namespace ck
......@@ -173,6 +173,26 @@ struct DynamicBuffer
}
}
template <typename DstBuffer, index_t NumElemsPerThread>
__host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
index_t src_offset,
index_t dst_offset,
bool is_valid_element) const
{
// Copy data from global to LDS memory using direct loads.
static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
"Source data must come from a global memory buffer.");
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"Destination data must be stored in an LDS memory buffer.");
amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
src_offset,
dst_buf.p_data_,
dst_offset,
is_valid_element,
element_space_size_);
}
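    // A minimal usage sketch (assuming a global-memory buffer src_buf and an LDS buffer
    // lds_buf holding the same element type T, with sizeof(T) * NumElemsPerThread == 4):
    //
    //   src_buf.template DirectCopyToLds<decltype(lds_buf), NumElemsPerThread>(
    //       lds_buf, src_offset, dst_offset, is_valid);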
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
......
......@@ -19,6 +19,15 @@ __device__ void block_sync_lds()
#endif
}
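// Barrier for the direct-load path: s_waitcnt vmcnt(0) also waits for the outstanding
// buffer loads that write straight into LDS (they are tracked by the vector-memory
// counter) before the workgroup hits s_barrier.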
__device__ void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
__device__ void s_nop()
{
#if 1
......
......@@ -95,9 +95,113 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
// convert fp32 to fp8
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
}
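// The pseudo-random bits used for stochastic rounding come from prand_generator with a
// fixed seed, keyed on the value and its address, so the conversion is reproducible for
// a given input and location.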
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native conversion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native conversion
    return f8_convert_sr<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
......@@ -124,6 +228,80 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
#endif
}
// convert fp16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native conversion
return f8_convert_rne<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native conversion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else
return f8_convert_rne<f8_t>(x);
#endif
}
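// The rounding mode is a compile-time switch: building with CK_USE_SR_F8_CONVERSION defined
// (e.g. -DCK_USE_SR_F8_CONVERSION) routes type_convert through the stochastic-rounding path,
// otherwise round-to-nearest-even is used.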
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
......@@ -174,17 +352,10 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
    return f8_convert_rne<f8_t>(x);
#endif
}
......@@ -205,26 +376,10 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_convert_rne<bf8_t>(x);
#endif
}
......@@ -248,17 +403,10 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<bf8_t>(type_convert<float>(x));
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_convert_rne<bf8_t>(x);
#endif
}
......@@ -331,104 +479,4 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
} // namespace ck
......@@ -3,12 +3,23 @@
#pragma once
#include <iostream>
#include <cmath>
#include <cstdlib>
#include <numeric>
#include <type_traits>
#include <sstream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
namespace ck {
namespace tensor_operation {
......@@ -22,6 +33,7 @@ namespace host {
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as the dimensions in the tensor descriptor are in GNCHW order
//
// @tparam NDimSpatial Number of spatial dimensions.
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
// @tparam OutDataType Output tensor data type.
......@@ -29,7 +41,9 @@ namespace host {
// operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation.
// @tparam NDimSpatial Number of spatial dimensions.
// @tparam NumAElementwiseTensor Number of A elementwise tensors.
// @tparam NumBElementwiseTensor Number of B elementwise tensors.
// @tparam NumDElementwiseTensor Number of D elementwise tensors.
//
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
......@@ -42,28 +56,35 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDTensor = 0,
ck::index_t NumAElementwiseTensor = 0,
ck::index_t NumBElementwiseTensor = 0,
ck::index_t NumDElementwiseTensor = 0,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors)
Argument(
const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors,
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors,
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors)
: input_{input},
weight_{weight},
output_{output},
d_tensors_{d_tensors},
elementwise_a_tensors_{elementwise_a_tensors},
elementwise_b_tensors_{elementwise_b_tensors},
elementwise_d_tensors_{elementwise_d_tensors},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
......@@ -78,7 +99,9 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<WeiDataType>& weight_;
Tensor<OutDataType>& output_;
const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors_;
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors_;
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
......@@ -119,42 +142,43 @@ struct ReferenceConvFwd : public device::BaseOperator
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
float v_in;
float v_wei;
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_in * v_wei;
InDataType v_in;
WeiDataType v_wei;
ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_in,
arg.input_(g, n, c, wi),
g,
n,
c,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, x),
g,
k,
c,
x);
v_acc +=
ck::type_convert<float>(v_in) * ck::type_convert<float>(v_wei);
}
}
}
OutDataType v_out;
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
if constexpr(NumDTensor == 0)
{
arg.out_element_op_(v_out, v_acc_converted);
}
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.d_tensors_[0](g, n, k, wo),
arg.d_tensors_[1](g, n, k, wo));
}
else
{
throw std::runtime_error("Output ElementOp not supported in reference.");
}
arg.output_(g, n, k, wo) = v_out;
OutDataType& v_out = arg.output_(g, n, k, wo);
ExecuteElementwiseOp(arg.out_element_op_,
arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_out,
v_acc_converted,
g,
n,
k,
wo);
};
make_ParallelTensorFunctor(func,
......@@ -191,44 +215,47 @@ struct ReferenceConvFwd : public device::BaseOperator
wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
float v_in;
float v_wei;
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
v_acc += v_in * v_wei;
InDataType v_in;
WeiDataType v_wei;
ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_in,
arg.input_(g, n, c, hi, wi),
g,
n,
c,
hi,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, y, x),
g,
k,
c,
y,
x);
v_acc += ck::type_convert<float>(v_in) *
ck::type_convert<float>(v_wei);
}
}
}
}
OutDataType v_out;
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
if constexpr(NumDTensor == 0)
{
arg.out_element_op_(v_out, v_acc_converted);
}
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(
v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, ho, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.d_tensors_[0](g, n, k, ho, wo),
arg.d_tensors_[1](g, n, k, ho, wo));
}
else
{
throw std::runtime_error("Output ElementOp not supported in reference.");
}
arg.output_(g, n, k, ho, wo) = v_out;
OutDataType& v_out = arg.output_(g, n, k, ho, wo);
ExecuteElementwiseOp(arg.out_element_op_,
arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_out,
v_acc_converted,
g,
n,
k,
ho,
wo);
};
make_ParallelTensorFunctor(func,
......@@ -275,47 +302,51 @@ struct ReferenceConvFwd : public device::BaseOperator
ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5])
{
float v_in;
float v_wei;
arg.in_element_op_(v_in,
ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
v_acc += v_in * v_wei;
InDataType v_in;
WeiDataType v_wei;
ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_in,
arg.input_(g, n, c, di, hi, wi),
g,
n,
c,
di,
hi,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, z, y, x),
g,
k,
c,
z,
y,
x);
v_acc += ck::type_convert<float>(v_in) *
ck::type_convert<float>(v_wei);
}
}
}
}
}
OutDataType v_out;
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
if constexpr(NumDTensor == 0)
{
arg.out_element_op_(v_out, v_acc_converted);
}
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(
v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, d_o, ho, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.d_tensors_[0](g, n, k, d_o, ho, wo),
arg.d_tensors_[1](g, n, k, d_o, ho, wo));
}
else
{
throw std::runtime_error("Output ElementOp not supported in reference.");
}
arg.output_(g, n, k, d_o, ho, wo) = v_out;
OutDataType& v_out = arg.output_(g, n, k, d_o, ho, wo);
ExecuteElementwiseOp(arg.out_element_op_,
arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_out,
v_acc_converted,
g,
n,
k,
d_o,
ho,
wo);
};
make_ParallelTensorFunctor(func,
......@@ -338,6 +369,36 @@ struct ReferenceConvFwd : public device::BaseOperator
}
};
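    // Applies an elementwise operation in the reference kernel, dispatching on the number of
    // auxiliary tensors (0, 1 or 2) and reading each auxiliary tensor at the same
    // multi-dimensional index as the output element.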
template <typename... Args,
typename ElementwiseOp,
typename ElementwiseTensor,
typename NumTensor,
typename T>
static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op,
ElementwiseTensor& elementwise_tensors,
NumTensor,
T& y,
const T& x,
Args... dims)
{
if constexpr(NumTensor::value == 0)
{
elementwise_op(y, x);
}
else if constexpr(NumTensor::value == 1)
{
elementwise_op(y, x, elementwise_tensors[0](dims...));
}
else if constexpr(NumTensor::value == 2)
{
elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
}
}
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
......@@ -349,17 +410,20 @@ struct ReferenceConvFwd : public device::BaseOperator
return NDimSpatial >= 1 && NDimSpatial <= 3;
}
static auto MakeArgument(const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors = {})
static auto MakeArgument(
const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors = {},
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors = {},
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors = {})
{
return Argument{input,
weight,
......@@ -371,7 +435,9 @@ struct ReferenceConvFwd : public device::BaseOperator
in_element_op,
wei_element_op,
out_element_op,
d_tensors};
elementwise_a_tensors,
elementwise_b_tensors,
elementwise_d_tensors};
}
static auto MakeInvoker() { return Invoker{}; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
typename DXDataType,
typename ComputeDataType>
struct ReferenceGroupnormBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<DYDataType>& dy_nhwgc,
const Tensor<XDataType>& x_nhwgc,
const Tensor<GammaDataType>& gamma_gc,
const Tensor<MeanInvStdDataType>& mean_ng,
const Tensor<MeanInvStdDataType>& inv_std_ng,
Tensor<DGammaDataType>& dgamma_gc,
Tensor<DBetaDataType>& dbeta_gc,
Tensor<DXDataType>& dx_nhwgc,
const std::vector<index_t> lengths)
: dy_nhwgc_(dy_nhwgc),
x_nhwgc_(x_nhwgc),
gamma_gc_(gamma_gc),
mean_ng_(mean_ng),
inv_std_ng_(inv_std_ng),
dgamma_gc_(dgamma_gc),
dbeta_gc_(dbeta_gc),
dx_nhwgc_(dx_nhwgc),
lengths_(lengths)
{
}
const Tensor<DYDataType>& dy_nhwgc_;
const Tensor<XDataType>& x_nhwgc_;
const Tensor<GammaDataType>& gamma_gc_;
const Tensor<MeanInvStdDataType>& mean_ng_;
const Tensor<MeanInvStdDataType>& inv_std_ng_;
Tensor<DGammaDataType>& dgamma_gc_;
Tensor<DBetaDataType>& dbeta_gc_;
Tensor<DXDataType>& dx_nhwgc_;
std::vector<index_t> lengths_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
int N = arg.lengths_[0];
int H = arg.lengths_[1];
int W = arg.lengths_[2];
int G = arg.lengths_[3];
int C = arg.lengths_[4];
// Calculate dgamma and dbeta
for(int g = 0; g < G; ++g)
for(int c = 0; c < C; ++c)
{
ComputeDataType dgamma = 0;
ComputeDataType dbeta = 0;
for(int n = 0; n < N; ++n)
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType mean =
ck::type_convert<ComputeDataType>(arg.mean_ng_(n, g));
ComputeDataType rstd =
ck::type_convert<ComputeDataType>(arg.inv_std_ng_(n, g));
dgamma += dy * rstd * (x - mean);
dbeta += dy;
}
arg.dgamma_gc_(g, c) = ck::type_convert<DGammaDataType>(dgamma);
arg.dbeta_gc_(g, c) = ck::type_convert<DBetaDataType>(dbeta);
}
// Calculate dx
int reduce_size = H * W * C;
for(int n = 0; n < N; ++n)
for(int g = 0; g < G; ++g)
{
ComputeDataType ds = 0;
ComputeDataType db = 0;
ComputeDataType mean = ck::type_convert<ComputeDataType>(arg.mean_ng_(n, g));
ComputeDataType rstd = ck::type_convert<ComputeDataType>(arg.inv_std_ng_(n, g));
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType gamma =
ck::type_convert<ComputeDataType>(arg.gamma_gc_(g, c));
ds += dy * gamma * x;
db += dy * gamma;
}
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType gamma =
ck::type_convert<ComputeDataType>(arg.gamma_gc_(g, c));
ComputeDataType b =
(db * mean - ds) * rstd * rstd * rstd / reduce_size;
ComputeDataType c1 = -b * mean - db * rstd / reduce_size;
arg.dx_nhwgc_(n, h, w, g, c) =
ck::type_convert<DXDataType>(dy * gamma * rstd + b * x + c1);
}
}
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<DYDataType>& dy_nhwgc,
const Tensor<XDataType>& x_nhwgc,
const Tensor<GammaDataType>& gamma_gc,
const Tensor<MeanInvStdDataType>& mean_ng,
const Tensor<MeanInvStdDataType>& inv_std_ng,
Tensor<DGammaDataType>& dgamma_gc,
Tensor<DBetaDataType>& dbeta_gc,
Tensor<DXDataType>& dx_nhwgc,
const std::vector<index_t> lengths)
{
return Argument{dy_nhwgc,
x_nhwgc,
gamma_gc,
mean_ng,
inv_std_ng,
dgamma_gc,
dbeta_gc,
dx_nhwgc,
lengths};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGroupnormBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
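For reference, the computation in ReferenceGroupnormBwd::Invoker::Run above corresponds to the following formulas, written out here for readability (not part of the commit). Per (g, c), with sums over n, h, w:

\[
d\gamma_{g,c} = \sum dy \cdot r \cdot (x - \mu), \qquad d\beta_{g,c} = \sum dy,
\]

and per (n, g) slice, with sums over h, w, c and reduce_size = H \cdot W \cdot C:

\[
d_s = \sum dy\,\gamma\,x, \qquad d_b = \sum dy\,\gamma, \qquad
b = \frac{(d_b\,\mu - d_s)\,r^3}{HWC}, \qquad c_1 = -b\,\mu - \frac{d_b\,r}{HWC},
\]
\[
dx = dy\,\gamma\,r + b\,x + c_1,
\]

where \mu is the stored mean and r the stored inverse standard deviation for that (n, g).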
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
typename DXDataType,
typename ComputeDataType>
struct ReferenceLayernormBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<DYDataType>& dy_m_n,
const Tensor<XDataType>& x_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<MeanInvStdDataType>& mean_m,
const Tensor<MeanInvStdDataType>& inv_std_m,
Tensor<DGammaDataType>& dgamma_n,
Tensor<DBetaDataType>& dbeta_n,
Tensor<DXDataType>& dx_m_n,
const std::vector<index_t> lengths)
: dy_m_n_(dy_m_n),
x_m_n_(x_m_n),
gamma_n_(gamma_n),
mean_m_(mean_m),
inv_std_m_(inv_std_m),
dgamma_n_(dgamma_n),
dbeta_n_(dbeta_n),
dx_m_n_(dx_m_n),
lengths_(lengths)
{
}
const Tensor<DYDataType>& dy_m_n_;
const Tensor<XDataType>& x_m_n_;
const Tensor<GammaDataType>& gamma_n_;
const Tensor<MeanInvStdDataType>& mean_m_;
const Tensor<MeanInvStdDataType>& inv_std_m_;
Tensor<DGammaDataType>& dgamma_n_;
Tensor<DBetaDataType>& dbeta_n_;
Tensor<DXDataType>& dx_m_n_;
std::vector<index_t> lengths_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
int M = arg.lengths_[0];
int N = arg.lengths_[1];
// Calculate dgamma and dbeta
for(int n = 0; n < N; ++n)
{
ComputeDataType dgamma = 0;
ComputeDataType dbeta = 0;
for(int m = 0; m < M; ++m)
{
ComputeDataType dy = ck::type_convert<ComputeDataType>(arg.dy_m_n_(m, n));
ComputeDataType x = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
ComputeDataType mean = ck::type_convert<ComputeDataType>(arg.mean_m_(m));
ComputeDataType rstd = ck::type_convert<ComputeDataType>(arg.inv_std_m_(m));
dgamma += dy * rstd * (x - mean);
dbeta += dy;
}
arg.dgamma_n_(n) = ck::type_convert<DGammaDataType>(dgamma);
arg.dbeta_n_(n) = ck::type_convert<DBetaDataType>(dbeta);
}
// Calculate dx
for(int m = 0; m < M; ++m)
{
ComputeDataType ds = 0;
ComputeDataType db = 0;
ComputeDataType mean = ck::type_convert<ComputeDataType>(arg.mean_m_(m));
ComputeDataType rstd = ck::type_convert<ComputeDataType>(arg.inv_std_m_(m));
for(int n = 0; n < N; ++n)
{
ComputeDataType dy = ck::type_convert<ComputeDataType>(arg.dy_m_n_(m, n));
ComputeDataType x = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
ComputeDataType gamma = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
ds += dy * gamma * x;
db += dy * gamma;
}
for(int n = 0; n < N; ++n)
{
ComputeDataType dy = ck::type_convert<ComputeDataType>(arg.dy_m_n_(m, n));
ComputeDataType x = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
ComputeDataType gamma = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
ComputeDataType b = (db * mean - ds) * rstd * rstd * rstd / N;
ComputeDataType c = -b * mean - db * rstd / N;
arg.dx_m_n_(m, n) = ck::type_convert<DXDataType>(dy * gamma * rstd + b * x + c);
}
}
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<DYDataType>& dy_m_n,
const Tensor<XDataType>& x_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<MeanInvStdDataType>& mean_m,
const Tensor<MeanInvStdDataType>& inv_std_m,
Tensor<DGammaDataType>& dgamma_n,
Tensor<DBetaDataType>& dbeta_n,
Tensor<DXDataType>& dx_m_n,
const std::vector<index_t> lengths)
{
return Argument{
dy_m_n, x_m_n, gamma_n, mean_m, inv_std_m, dgamma_n, dbeta_n, dx_m_n, lengths};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceLayernormBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
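A minimal host-side sketch of driving ReferenceLayernormBwd, assuming float for every data type and that the M x N host tensors have already been allocated elsewhere; the run_reference_layernorm_bwd helper name is illustrative, not part of this commit:

using RefLayernormBwd = ck::tensor_operation::host::
    ReferenceLayernormBwd<float, float, float, float, float, float, float, float>;

// Hedged sketch: tensors are assumed to be pre-sized to {M, N}, {N} or {M} as appropriate.
void run_reference_layernorm_bwd(const Tensor<float>& dy_m_n,
                                 const Tensor<float>& x_m_n,
                                 const Tensor<float>& gamma_n,
                                 const Tensor<float>& mean_m,
                                 const Tensor<float>& inv_std_m,
                                 Tensor<float>& dgamma_n,
                                 Tensor<float>& dbeta_n,
                                 Tensor<float>& dx_m_n,
                                 ck::index_t M,
                                 ck::index_t N)
{
    // lengths = {M, N}, matching how Invoker::Run reads lengths_[0] and lengths_[1].
    auto argument = RefLayernormBwd::MakeArgument(
        dy_m_n, x_m_n, gamma_n, mean_m, inv_std_m, dgamma_n, dbeta_n, dx_m_n, {M, N});
    auto invoker = RefLayernormBwd::MakeInvoker();
    invoker.Run(argument);
}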