Commit b89a88b5 authored by Adam Osewski

Merge branch 'develop' into wavelet_model

parents 41d5fca7 43c898f6
...@@ -10,4 +10,5 @@ struct StreamConfig
{
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_ = false;
+   int log_level_ = 0;
};
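For reference, the extended StreamConfig can still be brace-initialized at call sites, with the new field defaulting to 0. A self-contained sketch (assuming a HIP toolchain; the struct is copied here only so the snippet compiles on its own):

#include <hip/hip_runtime.h>

// Copy of the struct as it looks after this change, for illustration only.
struct StreamConfig
{
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_      = false;
    int log_level_         = 0; // new field added by this commit
};

int main()
{
    // Existing callers that wrote StreamConfig{stream, true} keep compiling;
    // log_level_ simply defaults to 0. Passing it explicitly opts into logging.
    StreamConfig quiet{nullptr, true};
    StreamConfig verbose{nullptr, true, /*log_level_=*/1};
    return quiet.log_level_ + verbose.log_level_ - 1;
}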
...@@ -701,9 +701,7 @@ struct BlockwiseGemmXdlops_v2
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];
-       const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
-       const auto blk_idx =
-           TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]);
+       const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
...@@ -922,13 +920,13 @@ struct BlockwiseGemmXdlops_v2
    }

    protected:
-   // A[M0, M1, M2, KPerThread]
+   // A[M0, M1, M2, KPack]
    static constexpr auto a_thread_desc_ =
-       make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+       make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
-   // B[N0, N1, N2, KPerThread]
+   // B[N0, N1, N2, KPack]
    static constexpr auto b_thread_desc_ =
-       make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+       make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
...
...@@ -16,7 +16,8 @@ template <index_t BlockSize,
          typename AccDataType,
          typename ThreadMap_M_K, // thread_id to m_k
          typename ThreadClusterDesc_M_K,
-         typename ThreadSliceDesc_M_K>
+         typename ThreadSliceDesc_M_K,
+         bool IgnoreNaN = false>
struct BlockwiseSoftmax
{
    static constexpr auto I0 = Number<0>{};
...@@ -27,11 +28,33 @@ struct BlockwiseSoftmax
    using ThreadSliceDesc_M = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
-   using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
-                                                   ThreadSliceDesc_M_K,
-                                                   ThreadSliceDesc_M,
-                                                   reduce::Max,
-                                                   false>;
+   using ThreadwiseMaxReduce = typename conditional<
+       IgnoreNaN,
+       ThreadwiseReduction<AccDataType,
+                           ThreadSliceDesc_M_K,
+                           ThreadSliceDesc_M,
+                           reduce::Max,
+                           false,
+                           detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>,
+       ThreadwiseReduction<AccDataType,
+                           ThreadSliceDesc_M_K,
+                           ThreadSliceDesc_M,
+                           reduce::Max,
+                           false>>::type;
+   using ThreadwiseSumReduce = typename conditional<
+       IgnoreNaN,
+       ThreadwiseReduction<AccDataType,
+                           ThreadSliceDesc_M_K,
+                           ThreadSliceDesc_M,
+                           reduce::Add,
+                           false,
+                           detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>,
+       ThreadwiseReduction<AccDataType,
+                           ThreadSliceDesc_M_K,
+                           ThreadSliceDesc_M,
+                           reduce::Add,
+                           false>>::type;
    using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
...@@ -49,12 +72,6 @@ struct BlockwiseSoftmax
                                                    reduce::Add,
                                                    false>;
-   using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
-                                                   ThreadSliceDesc_M_K,
-                                                   ThreadSliceDesc_M,
-                                                   reduce::Add,
-                                                   false>;
    using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;

    template <typename CThreadBuffer, typename WorkspaceBuffer>
...@@ -74,7 +91,9 @@ struct BlockwiseSoftmax
        static_for<0, MRepeat, 1>{}([&](auto iM) {
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-               in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM));
+               in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
+                                           ? 0
+                                           : math::exp(in_thread_buf[offset] - max_value_buf(iM));
            });
        });
...
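A standalone sketch of the two ideas added here, outside of CK's template machinery: picking a reduction policy at compile time with a conditional alias, and zeroing NaN inputs before the exp. All names below are illustrative, not CK API:

#include <cmath>
#include <cstdio>
#include <type_traits>

// Two accumulation policies, mirroring the "plain" vs. "NaN-ignoring" variants.
struct MaxAccumulate
{
    static void step(float& acc, float v) { acc = v > acc ? v : acc; }
};

struct MaxAccumulateIgnoreNan
{
    static void step(float& acc, float v)
    {
        if(!std::isnan(v))
            acc = v > acc ? v : acc;
    }
};

// Compile-time selection, analogous to the conditional<IgnoreNaN, ...>::type above.
template <bool IgnoreNaN>
using MaxReduce =
    typename std::conditional<IgnoreNaN, MaxAccumulateIgnoreNan, MaxAccumulate>::type;

template <bool IgnoreNaN>
float softmax_numerator(float x, float row_max)
{
    // Mirrors the guarded exp() in Run(): a NaN input contributes 0 instead of
    // poisoning the whole row.
    return IgnoreNaN && std::isnan(x) ? 0.f : std::exp(x - row_max);
}

int main()
{
    float acc = -1e30f;
    MaxReduce<true>::step(acc, NAN);
    MaxReduce<true>::step(acc, 2.f);
    std::printf("max=%f numerator(NaN)=%f\n", acc, softmax_numerator<true>(NAN, acc));
    return 0;
}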
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t NDim,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
const auto m = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
}
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(NDim > 1)
{
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using Gridwise5AryEltwise = Gridwise5AryElementwise_1D<ADataType,
BDataType,
CDataType,
DDataType,
EDataType,
FDataType,
ComputeDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
DGridDesc_M,
EGridDesc_M,
FGridDesc_M,
ElementwiseFunctor,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector,
DScalarPerVector,
EScalarPerVector,
FScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
const std::vector<index_t>& lengths,
const std::vector<index_t>& a_strides,
const std::vector<index_t>& b_strides,
const std::vector<index_t>& c_strides,
const std::vector<index_t>& d_strides,
const std::vector<index_t>& e_strides,
const std::vector<index_t>& f_strides,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
p_d_(p_d),
p_e_(p_e),
p_f_(p_f),
lengths_(lengths),
a_strides_(a_strides),
b_strides_(b_strides),
c_strides_(c_strides),
d_strides_(d_strides),
e_strides_(e_strides),
f_strides_(f_strides),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_);
e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_);
f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_);
}
const ADataType* p_a_;
const BDataType* p_b_;
const CDataType* p_c_;
const DDataType* p_d_;
const EDataType* p_e_;
FDataType* p_f_;
std::vector<index_t> lengths_;
AGridDesc_M a_grid_desc_m_;
BGridDesc_M b_grid_desc_m_;
CGridDesc_M c_grid_desc_m_;
DGridDesc_M d_grid_desc_m_;
EGridDesc_M e_grid_desc_m_;
FGridDesc_M f_grid_desc_m_;
std::vector<index_t> a_strides_;
std::vector<index_t> b_strides_;
std::vector<index_t> c_strides_;
std::vector<index_t> d_strides_;
std::vector<index_t> e_strides_;
std::vector<index_t> f_strides_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_5ary_elementwise_1d<Gridwise5AryEltwise,
ADataType,
BDataType,
CDataType,
DDataType,
EDataType,
FDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
DGridDesc_M,
EGridDesc_M,
FGridDesc_M,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.p_d_,
arg.p_e_,
arg.p_f_,
arg.a_grid_desc_m_,
arg.b_grid_desc_m_,
arg.c_grid_desc_m_,
arg.d_grid_desc_m_,
arg.e_grid_desc_m_,
arg.f_grid_desc_m_,
arg.functor_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); }
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.size() != NDim)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
bool ret = true;
if(!isLastDimensionCoalesced)
ret = scalarPerVector == 1;
else
ret = MPerThread % scalarPerVector == 0;
return ret;
};
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector))
return false;
return true;
};
static auto MakeArgument(std::array<const void*, 5> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{
return Argument{static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_outputs[0]),
lengths,
a_strides,
b_strides,
c_strides,
d_strides,
e_strides,
f_strides,
functor};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, 5> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_outputs[0]),
lengths,
input_strides[0],
input_strides[1],
input_strides[2],
input_strides[3],
input_strides[4],
output_strides[0],
functor);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Device5aryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< "DScalarPerVector = " << DScalarPerVector
<< "EScalarPerVector = " << EScalarPerVector
<< "FScalarPerVector = " << FScalarPerVector
<< ">";
// clang-format on
return str.str();
}
}; // Device5AryElementwise
} // namespace device
} // namespace tensor_operation
} // namespace ck
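The host-side descriptor logic above boils down to "flatten the N-d problem to 1-d, then right-pad to a multiple of gridSize * blockSize * MPerThread". A plain-C++ sketch of just that arithmetic (helper names are illustrative, not CK API; the 120/256 defaults come from the Argument above):

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Flattened length of an N-d tensor plus the right-padding PadDescriptor_M_1d adds.
struct Padded1D
{
    long flat; // product of all lengths
    long pad;  // extra elements so every block iteration sees full MPerThread chunks
};

Padded1D flatten_and_pad(const std::vector<long>& lengths,
                         long gridSize,
                         long blockSize,
                         long MPerThread)
{
    const long flat =
        std::accumulate(lengths.begin(), lengths.end(), 1L, std::multiplies<long>());
    const long loop_step = gridSize * blockSize * MPerThread;
    const long padded    = ((flat + loop_step - 1) / loop_step) * loop_step;
    return {flat, padded - flat};
}

int main()
{
    // Example shapes are arbitrary; MPerThread = 8 is just an illustrative choice.
    const auto r = flatten_and_pad({64, 1000}, /*gridSize=*/120, /*blockSize=*/256, 8);
    std::printf("flat=%ld pad=%ld\n", r.flat, r.pad);
    return 0;
}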
...@@ -129,6 +129,25 @@ namespace device {
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
+// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it
+// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1
+//
+// Detail- Packed tensor satisfies
+//   stride_0 = 1
+//   stride_i = stride_{i - 1} * extent_{i - 1}
+// So tensor
+//   [G0, G1, G2, M, N]
+// transposed into tensor
+//   [G0, G2, G1, M, N]
+// with strides
+//   [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
+// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some
+// strides from input tensor extents so finer dimension information is lost. Merging dimensions is
+// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
+//
+// Might need to expose dimension order to the interface to fully support
+// TensorSpecialization::Packed.
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
...
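The packed-tensor condition stated in the comment above can be checked mechanically; a small host-side sketch (illustrative helper, not part of CK), using the same example layout as the comment:

#include <cstdio>
#include <vector>

// A tensor is "packed" in the sense above if, with dims ordered slowest-to-fastest,
//   stride[last] == 1  and  stride[i] == stride[i + 1] * extent[i + 1].
bool is_packed(const std::vector<long>& extents, const std::vector<long>& strides)
{
    const int n = static_cast<int>(extents.size());
    if(strides.back() != 1)
        return false;
    for(int i = n - 2; i >= 0; --i)
        if(strides[i] != strides[i + 1] * extents[i + 1])
            return false;
    return true;
}

int main()
{
    // The [G0, G2, G1, M, N] tensor from the comment, with the strides given there:
    // it satisfies the packed condition even though it represents a transposed view.
    const long G0 = 2, G1 = 3, G2 = 4, M = 8, N = 16;
    std::vector<long> extents{G0, G2, G1, M, N};
    std::vector<long> strides{G2 * G1 * M * N, G1 * M * N, M * N, N, 1};
    std::printf("packed? %d (G0 unused in strides: %ld)\n", is_packed(extents, strides) ? 1 : 0, G0);
    return 0;
}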
...@@ -54,33 +54,6 @@ struct DeviceBatchedGemmGemm : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CLayout,
-          typename ADataType,
-          typename B0DataType,
-          typename B1DataType,
-          typename CDataType,
-          typename AElementwiseOperation,
-          typename B0ElementwiseOperation,
-          typename Acc0ElementwiseOperation,
-          typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
-using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
-                                                                       B0Layout,
-                                                                       B1Layout,
-                                                                       CLayout,
-                                                                       ADataType,
-                                                                       B0DataType,
-                                                                       B1DataType,
-                                                                       CDataType,
-                                                                       AElementwiseOperation,
-                                                                       B0ElementwiseOperation,
-                                                                       Acc0ElementwiseOperation,
-                                                                       B1ElementwiseOperation,
-                                                                       CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -12,9 +12,11 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/io.hpp"

namespace ck {
namespace tensor_operation {
...@@ -188,6 +190,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
+   static constexpr auto matrix_padder =
+       GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
+           MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -203,92 +209,18 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
            }
        }();
-       const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-       const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
-       const auto MPad = M - MRaw;
-       const auto KPad = K - KRaw;
-       if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
-                    GemmSpec == GemmSpecialization::MNKPadding)
-       {
-           // pad both M and K
-           assert(K % AK1 == 0);
-           const auto AK0 = K / AK1;
-           const auto a_grid_desc_m_k =
-               transform_tensor_descriptor(a_grid_desc_mraw_kraw,
-                                           make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                      make_right_pad_transform(KRaw, KPad)),
-                                           make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
-           const auto a_grid_desc_ak0_m_ak1 =
-               transform_tensor_descriptor(a_grid_desc_m_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                      make_pass_through_transform(M)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return a_grid_desc_ak0_m_ak1;
-       }
-       else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                         GemmSpec == GemmSpecialization::MNPadding)
-       {
-           // pad M, but not K
-           assert(KRaw % AK1 == 0);
-           const auto AK0 = KRaw / AK1;
-           const auto a_grid_desc_ak0_m_ak1 =
-               transform_tensor_descriptor(a_grid_desc_mraw_kraw,
-                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                      make_right_pad_transform(MRaw, MPad)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return a_grid_desc_ak0_m_ak1;
-       }
-       else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                         GemmSpec == GemmSpecialization::NKPadding)
-       {
-           // pad K, but not M
-           assert(K % AK1 == 0);
-           const auto AK0 = K / AK1;
-           const auto a_grid_desc_m_k = transform_tensor_descriptor(
-               a_grid_desc_mraw_kraw,
-               make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
-           const auto a_grid_desc_ak0_m_ak1 =
-               transform_tensor_descriptor(a_grid_desc_m_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                      make_pass_through_transform(MRaw)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return a_grid_desc_ak0_m_ak1;
-       }
-       else
-       {
-           // not pad M or K
-           assert(KRaw % AK1 == 0);
-           const auto AK0 = KRaw / AK1;
-           const auto a_grid_desc_ak0_m_ak1 =
-               transform_tensor_descriptor(a_grid_desc_mraw_kraw,
-                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                      make_pass_through_transform(MRaw)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return a_grid_desc_ak0_m_ak1;
-       }
+       const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+       const auto M = a_grid_desc_m_k.GetLength(I0);
+       const auto K = a_grid_desc_m_k.GetLength(I1);
+       const auto AK0 = K / AK1;
+       return transform_tensor_descriptor(a_grid_desc_m_k,
+                                          make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                                                     make_pass_through_transform(M)),
+                                          make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                          make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
...@@ -306,84 +238,18 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
            }
        }();
-       const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
-       const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
-       const auto NPad = N - NRaw;
-       const auto KPad = K - KRaw;
-       if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
-                    GemmSpec == GemmSpecialization::MNKPadding)
-       {
-           // pad both N and K
-           const auto BK0 = K / BK1;
-           const auto b_grid_desc_n_k =
-               transform_tensor_descriptor(b_grid_desc_nraw_kraw,
-                                           make_tuple(make_right_pad_transform(NRaw, NPad),
-                                                      make_right_pad_transform(KRaw, KPad)),
-                                           make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
-           const auto b_grid_desc_bk0_n_bk1 =
-               transform_tensor_descriptor(b_grid_desc_n_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                      make_pass_through_transform(N)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return b_grid_desc_bk0_n_bk1;
-       }
-       else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                         GemmSpec == GemmSpecialization::MNPadding)
-       {
-           // pad N, but not K
-           const auto BK0 = KRaw / BK1;
-           const auto b_grid_desc_bk0_n_bk1 =
-               transform_tensor_descriptor(b_grid_desc_nraw_kraw,
-                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                      make_right_pad_transform(NRaw, NPad)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return b_grid_desc_bk0_n_bk1;
-       }
-       else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                         GemmSpec == GemmSpecialization::MKPadding)
-       {
-           // pad K, but not N
-           const auto BK0 = K / BK1;
-           const auto b_grid_desc_n_k = transform_tensor_descriptor(
-               b_grid_desc_nraw_kraw,
-               make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
-           const auto b_grid_desc_bk0_n_bk1 =
-               transform_tensor_descriptor(b_grid_desc_n_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                      make_pass_through_transform(NRaw)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return b_grid_desc_bk0_n_bk1;
-       }
-       else
-       {
-           // not pad N or K
-           const auto BK0 = KRaw / BK1;
-           const auto b_grid_desc_bk0_n_bk1 =
-               transform_tensor_descriptor(b_grid_desc_nraw_kraw,
-                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                      make_pass_through_transform(NRaw)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return b_grid_desc_bk0_n_bk1;
-       }
+       const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+       const auto N = b_grid_desc_n_k.GetLength(I0);
+       const auto K = b_grid_desc_n_k.GetLength(I1);
+       const auto BK0 = K / BK1;
+       return transform_tensor_descriptor(b_grid_desc_n_k,
+                                          make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                                                     make_pass_through_transform(N)),
+                                          make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                          make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
...@@ -402,47 +268,19 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
            }
        }();
-       const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
-       const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;
-       const auto NPad = N - NRaw;
-       const auto KPad = K - KRaw;
-       // TODO: implement finer-grained padding
-       if constexpr(GemmSpec == GemmSpecialization::Default)
-       {
-           const auto B1K0 = KRaw / B1K1;
-           const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-               b1_grid_desc_nraw_kraw,
-               make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
-                          make_pass_through_transform(NRaw)),
-               make_tuple(Sequence<1>{}, Sequence<0>{}),
-               make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return b1_grid_desc_bk0_n_bk1;
-       }
-       else
-       {
-           // pad both B1N and B1K
-           const auto B1K0 = K / B1K1;
-           const auto b1_grid_desc_n_k =
-               transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
-                                           make_tuple(make_right_pad_transform(NRaw, NPad),
-                                                      make_right_pad_transform(KRaw, KPad)),
-                                           make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
-           const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-               b1_grid_desc_n_k,
-               make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
-                          make_pass_through_transform(N)),
-               make_tuple(Sequence<1>{}, Sequence<0>{}),
-               make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-           return b1_grid_desc_bk0_n_bk1;
-       }
+       const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
+       const auto N = b1_grid_desc_n_k.GetLength(I0);
+       const auto K = b1_grid_desc_n_k.GetLength(I1);
+       const auto B1K0 = K / B1K1;
+       return transform_tensor_descriptor(
+           b1_grid_desc_n_k,
+           make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
+                      make_pass_through_transform(N)),
+           make_tuple(Sequence<1>{}, Sequence<0>{}),
+           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
...@@ -460,47 +298,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
            }
        }();
-       const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-       const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
-       const auto MPad = M - MRaw;
-       const auto NPad = N - NRaw;
-       if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
-                    GemmSpec == GemmSpecialization::MNKPadding)
-       {
-           // pad M and N
-           return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
-                                              make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                         make_right_pad_transform(NRaw, NPad)),
-                                              make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                              make_tuple(Sequence<0>{}, Sequence<1>{}));
-       }
-       else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                         GemmSpec == GemmSpecialization::MKPadding)
-       {
-           // pad M, but not N
-           return transform_tensor_descriptor(
-               c_grid_desc_mraw_nraw,
-               make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
-       }
-       else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                         GemmSpec == GemmSpecialization::NKPadding)
-       {
-           // pad N, but not M
-           return transform_tensor_descriptor(
-               c_grid_desc_mraw_nraw,
-               make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
-       }
-       else
-       {
-           // not pad M or N
-           return c_grid_desc_mraw_nraw;
-       }
+       return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }
    struct ComputeBasePtrOfStridedBatch
...@@ -651,7 +449,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              batch_count_(Batch),
-             compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
+             compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
+             raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
...@@ -665,6 +464,14 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
            }
        }
+       void Print() const
+       {
+           std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
+           std::cout << "B0[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
+           std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl;
+           std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
+       }
        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
...@@ -684,6 +491,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
        CElementwiseOperation c_element_op_;
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+       // For robust IsSupportedArgument() check
+       std::vector<index_t> raw_lengths_m_n_k_o_;
    };

    // Invoker
...@@ -787,6 +597,31 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
            return false;
        }
+       // Note: we need raw lengths since threadwise copy can not handle vector load when part of
+       // vector is out of bounds
+       const auto MRaw      = arg.raw_lengths_m_n_k_o_[0];
+       const auto NRaw      = arg.raw_lengths_m_n_k_o_[1];
+       const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
+       const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
+       // Check scalar per vector requirement
+       const auto a_extent_lowest =
+           is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
+       const auto b_extent_lowest =
+           is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
+       const auto b1_extent_lowest =
+           is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
+       const auto c_extent_lowest =
+           is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
+       if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
+            b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
+            b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
+            c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
+       {
+           return false;
+       }
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
...@@ -903,7 +738,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
            << MPerBlock << ", "
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
-           << B1K1 << ">";
+           << B1K1 << ", "
+           << getGemmSpecializationString(GemmSpec) << ">";
        // clang-format on
        return str.str();
...
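The GemmGemmPadder introduced above replaces the hand-written GemmSpecialization branches, but the padding itself is still the "round the raw extent up to the block tile" arithmetic that the removed code spelled out. A self-contained sketch of that arithmetic (illustrative, not the CK implementation; the tile sizes are example values, not defaults):

#include <cstdio>

// Round a raw extent up to a multiple of the block tile, as the removed
// make_right_pad_transform branches did; pad == false means "leave it raw".
long padded_extent(long raw, long per_block, bool pad)
{
    if(!pad)
        return raw;
    return (raw + per_block - 1) / per_block * per_block;
}

int main()
{
    // Example: an MNKPadding-style configuration with MPerBlock = 128,
    // NPerBlock = 128, KPerBlock = 32, Gemm1NPerBlock = 64 (illustrative tiles).
    const long MRaw = 1000, NRaw = 1000, KRaw = 65, Gemm1NRaw = 130;
    std::printf("M: %ld -> %ld\n", MRaw, padded_extent(MRaw, 128, true));
    std::printf("N: %ld -> %ld\n", NRaw, padded_extent(NRaw, 128, true));
    std::printf("K: %ld -> %ld\n", KRaw, padded_extent(KRaw, 32, true));
    std::printf("O: %ld -> %ld\n", Gemm1NRaw, padded_extent(Gemm1NRaw, 64, true));
    return 0;
}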
...@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/io.hpp"

namespace ck {
namespace tensor_operation {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation>
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
{
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a0,
const void* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
const void* p_b1,
std::array<const void*, NumD1Tensor> p_d1s,
void* p_e1,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA0,
ck::index_t StrideB0,
std::array<ck::index_t, NumD0Tensor> StrideD0s,
ck::index_t StrideB1,
std::array<ck::index_t, NumD1Tensor> StrideD1s,
ck::index_t StrideE1,
ck::index_t BatchStrideA0,
ck::index_t BatchStrideB0,
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
ck::index_t BatchStrideB1,
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
ck::index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename A0B0B1DataType,
typename D0sPointer,
typename D1sPointer,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
typename A0GridDesc_AK0_M_AK1,
typename B0GridDesc_BK0_N_BK1,
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1,
typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2E1TileMap,
typename ComputeBasePtrOfStridedBatch,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_xdl_cshuffle_v1(
const A0B0B1DataType* __restrict__ p_a0_grid,
const A0B0B1DataType* __restrict__ p_b0_grid,
D0sPointer p_d0s_grid,
const A0B0B1DataType* __restrict__ p_b1_grid,
D1sPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
const A0ElementwiseOperation a0_element_op,
const B0ElementwiseOperation b0_element_op,
const CDE0ElementwiseOperation cde0_element_op,
const B1ElementwiseOperation b1_element_op,
const CDE1ElementwiseOperation cde1_element_op,
const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2E1TileMap block_2_e1tile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
});
static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) {
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset;
});
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a0_grid + a_batch_offset,
p_b0_grid + b_batch_offset,
p_d0s_grid,
p_b1_grid + b1_batch_offset,
p_d1s_grid,
p_e1_grid + c_batch_offset,
p_shared,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op,
a0_grid_desc_ak0_m_ak1,
b0_grid_desc_bk0_n_bk1,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
e1_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_e1tile_map);
#else
ignore = p_a0_grid;
ignore = p_b0_grid;
ignore = p_d0s_grid;
ignore = p_b1_grid;
ignore = p_d1s_grid;
ignore = p_e1_grid;
ignore = a0_element_op;
ignore = b0_element_op;
ignore = cde0_element_op;
ignore = b1_element_op;
ignore = cde1_element_op;
ignore = a0_grid_desc_ak0_m_ak1;
ignore = b0_grid_desc_bk0_n_bk1;
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_e1tile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
#endif
}
// Computes C = A * B0 * B1
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
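// The operand and functor names below suggest the intended dataflow is roughly:
//   Acc0 = A0 * B0;                 cde0_element_op(Acc0, D0s...) feeds Gemm1
//   Acc1 = (Gemm1 input) * B1;      cde1_element_op(Acc1, D1s...) -> E1
// i.e. each of the two chained GEMMs can fuse extra "D" tensors into its epilogue.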
template <typename A0Layout,
typename B0Layout, // B0Layout
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename Acc0DataType,
typename D0sDataType,
typename B1DataType,
typename Acc1DataType,
typename C1ShuffleDataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
bool PadGemm0M,
bool PadGemm0N,
bool PadGemm0K,
bool PadGemm1N,
bool PadGemm1K,
index_t NumGemm0KPrefetchStage,
index_t BlockSize,
index_t Gemm0MPerBlock,
index_t Gemm0NPerBlock,
index_t Gemm0KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t A0K1,
index_t B0K1,
index_t B1K1,
index_t Gemm0MPerXdl,
index_t Gemm0NPerXdl,
index_t Gemm0MXdlPerWave,
index_t Gemm0NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
typename A0BlockTransferThreadClusterArrangeOrder,
typename A0BlockTransferSrcAccessOrder,
index_t A0BlockTransferSrcVectorDim,
index_t A0BlockTransferSrcScalarPerVector,
index_t A0BlockTransferDstScalarPerVector_AK1,
bool A0BlockLdsExtraM,
typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
index_t B0BlockTransferSrcVectorDim,
index_t B0BlockTransferSrcScalarPerVector,
index_t B0BlockTransferDstScalarPerVector_BK1,
bool B0BlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
index_t C1ShuffleMXdlPerWavePerShuffle,
index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
: public DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto gemm0_padder =
GemmPadder_v2<PadGemm0M, PadGemm0N, PadGemm0K, index_t, index_t, index_t>{
Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock};
static constexpr auto gemm1_padder =
GemmPadder_v2<PadGemm0M, PadGemm1N, PadGemm1K, index_t, index_t, index_t>{
Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock};
// for Gemm0
static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0)
{
const auto a0_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, A0Layout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA0, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, A0Layout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA0));
}
}();
return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw);
}
// for Gemm0
static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b0_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B0Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B0Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw);
}
// for Gemm0
template <typename DLay>
static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0)
{
const auto d0_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideD0, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideD0));
}
}();
return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw);
}
// for Gemm1
static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw);
}
// for Gemm1
template <typename ELay>
static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1)
{
const auto e1_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE1, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE1));
}
}();
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
}
static auto MakeD0sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
const std::array<index_t, NumD1Tensor>& NRaws,
const std::array<index_t, NumD1Tensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
return DeviceOp::MakeD0GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumD0Tensor>{});
}
static auto MakeD1sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
const std::array<index_t, NumD1Tensor>& NRaws,
const std::array<index_t, NumD1Tensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
return DeviceOp::MakeE1GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumD1Tensor>{});
}
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1)
: BatchStrideA0_(BatchStrideA0),
BatchStrideB0_(BatchStrideB0),
BatchStrideD0s_(BatchStrideD0s),
BatchStrideB1_(BatchStrideB1),
BatchStrideD1s_(BatchStrideD1s),
BatchStrideE1_(BatchStrideE1)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA0_);
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
}
template <index_t I>
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d1_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE1_);
}
template <index_t I>
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
}
private:
index_t BatchStrideA0_;
index_t BatchStrideB0_;
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
index_t BatchStrideB1_;
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
index_t BatchStrideE1_;
};
using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(1, 1, 1));
using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(1, 1, 1));
using D0sGridDesc_M_N = remove_cvref_t<decltype(MakeD0sGridDescriptor_M_N({}, {}, {}))>;
using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(1, 1, 1));
using D1sGridDesc_M_N = remove_cvref_t<decltype(MakeD1sGridDescriptor_M_N({}, {}, {}))>;
using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N<E1Layout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
A0DataType, // TODO: distinguish A/B datatype
Acc0DataType,
D0sDataType,
Acc1DataType,
C1ShuffleDataType,
D1sDataType,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
A0GridDesc_M_K,
B0GridDesc_N_K,
D0sGridDesc_M_N,
B1GridDesc_N_K,
D1sGridDesc_M_N,
E1GridDesc_M_N,
NumGemm0KPrefetchStage,
BlockSize,
Gemm0MPerBlock,
Gemm0NPerBlock,
Gemm0KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
A0K1,
B0K1,
B1K1,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm0NXdlPerWave,
Gemm1NXdlPerWave,
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
A0BlockTransferThreadClusterArrangeOrder,
A0BlockTransferSrcAccessOrder,
A0BlockTransferSrcVectorDim,
A0BlockTransferSrcScalarPerVector,
A0BlockTransferDstScalarPerVector_AK1,
true,
A0BlockLdsExtraM,
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
B0BlockTransferThreadClusterArrangeOrder,
B0BlockTransferSrcAccessOrder,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_BK1,
true,
B0BlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
C1ShuffleMXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>;
using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>;
using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(const A0DataType* p_a0_grid,
const B0DataType* p_b0_grid,
std::array<const void*, NumD0Tensor> p_d0s_grid,
const B1DataType* p_b1_grid,
std::array<const void*, NumD1Tensor> p_d1s_grid,
E1DataType* p_e1_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw, // = ORaw
index_t Batch,
index_t StrideA0,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op)
: p_a0_grid_{p_a0_grid},
p_b0_grid_{p_b0_grid},
p_d0s_grid_{},
p_b1_grid_{p_b1_grid},
p_d1s_grid_{},
p_e1_grid_{p_e1_grid},
a0_grid_desc_m_k_{DeviceOp::MakeA0GridDescriptor_M_K(MRaw, KRaw, StrideA0)},
b0_grid_desc_n_k_{DeviceOp::MakeB0GridDescriptor_N_K(KRaw, NRaw, StrideB0)},
d0s_grid_desc_m_n_{},
b1_grid_desc_n_k_{DeviceOp::MakeB1GridDescriptor_N_K(NRaw, Gemm1NRaw, StrideB1)},
d1s_grid_desc_m_n_{},
e1_grid_desc_m_n_{
DeviceOp::MakeE1GridDescriptor_M_N<E1Layout>(MRaw, Gemm1NRaw, StrideE1)},
a0_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)},
b0_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)},
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
b1_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)},
d1s_grid_desc_mblock_mperblock_nblock_nperblock_{},
e1_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)},
a0_element_op_{a0_element_op},
b0_element_op_{b0_element_op},
cde0_element_op_{cde0_element_op},
b1_element_op_{b1_element_op},
cde1_element_op_{cde1_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA0,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1}
{
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", "
<< d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
<< std::endl;
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_d0s_grid[i]);
// D0 desc
d0s_grid_desc_m_n_(i) =
DeviceOp::MakeD0GridDescriptor_M_N<D0Layout>(MRaw, NRaw, StrideD0s[i]);
});
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
using D1Layout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
// D1 pointer
p_d1s_grid_(i) = static_cast<const D1DataType*>(p_d1s_grid[i]);
// D1 desc
d1s_grid_desc_m_n_(i) =
DeviceOp::MakeE1GridDescriptor_M_N<D1Layout>(MRaw, Gemm1NRaw, StrideD1s[i]);
});
if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_,
b0_grid_desc_n_k_,
b1_grid_desc_n_k_,
e1_grid_desc_m_n_,
block_2_e1tile_map_))
{
e1_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e1_grid_desc_m_n_);
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0s_grid_desc_m_n_);
d1s_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d1s_grid_desc_m_n_);
}
}
// private:
// pointers
const A0DataType* p_a0_grid_;
const B0DataType* p_b0_grid_;
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
const B1DataType* p_b1_grid_;
typename GridwiseGemm::D1sGridPointer p_d1s_grid_;
E1DataType* p_e1_grid_;
// tensor descriptors for problem definiton
A0GridDesc_M_K a0_grid_desc_m_k_;
B0GridDesc_N_K b0_grid_desc_n_k_;
D0sGridDesc_M_N d0s_grid_desc_m_n_;
B1GridDesc_N_K b1_grid_desc_n_k_;
D1sGridDesc_M_N d1s_grid_desc_m_n_;
E1GridDesc_M_N e1_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_;
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e1-tile map
typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_;
// element-wise op
A0ElementwiseOperation a0_element_op_;
B0ElementwiseOperation b0_element_op_;
CDE0ElementwiseOperation cde0_element_op_;
B1ElementwiseOperation b1_element_op_;
CDE1ElementwiseOperation cde1_element_op_;
// batch
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_,
arg.e1_grid_desc_m_n_,
arg.block_2_e1tile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_;
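            // The launch packs batch_count_ copies of the per-batch tile grid back to back;
            // the kernel recovers the batch index from the block id.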
// Gemm0_K
const auto K = arg.a0_grid_desc_m_k_.GetLength(I1);
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_gemm_xdl_cshuffle_v1<
GridwiseGemm,
                A0DataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::D0sGridPointer,
typename GridwiseGemm::D1sGridPointer,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
DeviceOp::A0GridDesc_AK0_M_AK1,
DeviceOp::B0GridDesc_BK0_N_BK1,
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2E1TileMap,
ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a0_grid_,
arg.p_b0_grid_,
arg.p_d0s_grid_,
arg.p_b1_grid_,
arg.p_d1s_grid_,
arg.p_e1_grid_,
arg.a0_element_op_,
arg.b0_element_op_,
arg.cde0_element_op_,
arg.b1_element_op_,
arg.cde1_element_op_,
arg.a0_grid_desc_ak0_m_ak1_,
arg.b0_grid_desc_bk0_n_bk1_,
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_e1tile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_);
};
            // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only
            // need to be concerned with Gemm0's K loop
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_,
arg.e1_grid_desc_m_n_,
arg.block_2_e1tile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const A0DataType* p_a0,
const B0DataType* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
const B1DataType* p_b1,
std::array<const void*, NumD1Tensor> p_d1s,
E1DataType* p_e1,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
index_t StrideA0,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op)
{
return Argument{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a0,
const void* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
const void* p_b1,
std::array<const void*, NumD1Tensor> p_d1s,
void* p_e1,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
index_t StrideA0,
index_t StrideB0,
std::array<ck::index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<ck::index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op) override
{
return std::make_unique<Argument>(static_cast<const A0DataType*>(p_a0),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_e1),
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
StrideA0,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< Gemm0MPerBlock << ", "
<< Gemm0NPerBlock << ", "
<< Gemm0KPerBlock << ", "
<< A0K1 << ", "
<< B0K1 << ", "
<< B1K1 << ", "
<< Gemm0MPerXdl << ", "
<< Gemm0NPerXdl << ", "
<< Gemm0MXdlPerWave << ", "
<< Gemm0NXdlPerWave << ", "
<< Gemm1NXdlPerWave << "> ";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -54,34 +54,6 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator ...@@ -54,34 +54,6 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmSoftmaxGemmPtr =
std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
Acc0ElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
std::vector<index_t> c_gs_ms_os_lengths,
std::vector<index_t> c_gs_ms_os_strides,
ck::index_t StrideA,
ck::index_t StrideB0,
ck::index_t StrideB1,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB0,
ck::index_t BatchStrideB1,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
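    // Per-batch base offsets (in elements); readfirstlane keeps these wavefront-uniform
    // values in scalar registers.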
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
#endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
}
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
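// The second GEMM consumes a row-wise softmax of Acc0, i.e. C = softmax(A * B0) * B1, and the
// output C is written through the user-specified G/M/Gemm1N permutation.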
template <typename ALayout,
typename BLayout, // B0Layout
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
BLayout,
B1Layout,
CPermuteNumDims_G_M_Gemm1N,
ADataType,
BDataType,
B1DataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
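    // GemmGemmPadder rounds M/N/K/Gemm1N up to multiples of the corresponding *PerBlock tile
    // sizes; GemmSpec selects which of these dimensions actually get padded.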
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
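        // Split K into (AK0, AK1) and reorder to [AK0, M, AK1]; AK1 is the inner dimension used
        // for vectorized A block transfers.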
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
        // lengths for N0, N1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
        // lengths for N0, N1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
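    // The descriptor types are deduced by calling the Make* helpers with dummy arguments; only
    // the resulting types matter here, not the values.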
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
CGridDesc_G_M_N c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideB1_(BatchStrideB1),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
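        // C may carry an arbitrary G/M/N permutation, so its batch offset is computed from the
        // unpadded G_M_N descriptor rather than from a single BatchStrideC value.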
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideB1_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN>;
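    // The trailing matrix_padder.PadN argument tells the gridwise kernel that Gemm0's N
    // dimension may be padded, so the softmax over N can account for the padded tail.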
// Argument
// FIXME: constness
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw, // = ORaw
index_t Batch,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
b1_grid_desc_bk0_n_bk1_{
DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
index_t c_extent_lowest_;
index_t c_stride_lowest_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
// Gemm0_K
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
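            // Recover the (padded) Gemm0 K extent as AK0 * AK1 to decide whether the main
            // K block loop is needed.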
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
GridwiseGemm,
                ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_);
};
            // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only
            // need to be concerned with Gemm0's K loop
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
}
        // Note: we need the raw lengths since the threadwise copy cannot handle a vector load
        // when part of the vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
        // Check vector store requirement; assumes the last (N) dimension is contiguous
if(arg.c_stride_lowest_ != 1)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_b1,
p_c,
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides,
StrideA,
StrideB,
StrideB1,
BatchStrideA,
BatchStrideB,
BatchStrideB1,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_b1,
void* p_c,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides,
StrideA,
StrideB,
StrideB1,
BatchStrideA,
BatchStrideB,
BatchStrideB1,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -111,6 +112,15 @@ __global__ void ...@@ -111,6 +112,15 @@ __global__ void
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
// When using NPadding as GemmSpecialization, AccElementwiseOperation should be set to
// ScaleAndResetNaNToMinusInfinity.
// if !isNan(AccElement)
// AccElement *= scale
// else
// AccElement = -INFINITY
// Otherwise, result may be wrong.
template <typename ALayout, template <typename ALayout,
typename BLayout, // B0Layout typename BLayout, // B0Layout
typename B1Layout, typename B1Layout,
...@@ -189,6 +199,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -189,6 +199,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -204,92 +221,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -204,92 +221,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
} }
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1; const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto a_grid_desc_ak0_m_ak1 = const auto AK0 = K / AK1;
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1; return transform_tensor_descriptor(a_grid_desc_m_k,
} make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
...@@ -307,84 +250,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -307,84 +250,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
} }
}(); }();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; const auto N = b_grid_desc_n_k.GetLength(I0);
} const auto K = b_grid_desc_n_k.GetLength(I1);
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 = const auto BK0 = K / BK1;
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; return transform_tensor_descriptor(b_grid_desc_n_k,
} make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1 // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
...@@ -403,47 +280,19 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -403,47 +280,19 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
} }
}(); }();
const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock; const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;
const auto NPad = N - NRaw; const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto KPad = K - KRaw; const auto K = b1_grid_desc_n_k.GetLength(I1);
// TODO: implement finer-grained padding const auto B1K0 = K / B1K1;
if constexpr(GemmSpec == GemmSpecialization::Default)
{
const auto B1K0 = KRaw / B1K1;
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( return transform_tensor_descriptor(
b1_grid_desc_nraw_kraw, b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(NRaw)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b1_grid_desc_bk0_n_bk1;
}
else
{
// pad both B1N and B1K
const auto B1K0 = K / B1K1;
const auto b1_grid_desc_n_k =
transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b1_grid_desc_bk0_n_bk1;
}
} }
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
...@@ -461,47 +310,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -461,47 +310,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
} }
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
} }
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
...@@ -608,7 +417,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -608,7 +417,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched,
matrix_padder.PadN>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -652,7 +462,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -652,7 +462,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
batch_count_(Batch), batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC} compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -685,6 +496,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -685,6 +496,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
}; };
// Invoker // Invoker
...@@ -788,6 +602,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -788,6 +602,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return false; return false;
} }
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
...@@ -904,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -904,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< B1K1 << ">"; << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim>
struct DeviceBatchNormFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
void* p_y,
double exponentialAverageFactor,
void* resultRunningMean,
void* resultRunningVariance,
double epsilon,
void* resultSaveMean,
void* resultSaveInvVariance) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim>
struct DeviceBatchNormInfer : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
double epsilon,
const void* estimatedMean,
const void* estimatedInvVariance,
void* p_y) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t NDim,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
const auto M = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(M, loop_step) - M;
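        // Pad the flattened length up to a multiple of gridSize * blockSize * MPerThread so the
        // grid-stride loop needs no tail handling.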
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(M, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
}
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(NDim > 1)
{
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
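// The descriptor types below are deduced from dummy lengths/strides/grid/block
// values; only the resulting type matters here. The real descriptors are built
// per problem in Argument's constructor.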
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
ComputeDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
ElementwiseFunctor,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
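// The gridwise kernel is parameterized with the element types, the flattened 1-D
// grid descriptors, the functor, and the per-tensor vector widths that are
// validated in IsSupportedArgument below.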
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const std::vector<index_t>& lengths,
const std::vector<index_t>& a_strides,
const std::vector<index_t>& b_strides,
const std::vector<index_t>& c_strides,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
lengths_(lengths),
a_strides_(a_strides),
b_strides_(b_strides),
c_strides_(c_strides),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
}
const ADataType* p_a_;
const BDataType* p_b_;
CDataType* p_c_;
std::vector<index_t> lengths_;
AGridDesc_M a_grid_desc_m_;
BGridDesc_M b_grid_desc_m_;
CGridDesc_M c_grid_desc_m_;
std::vector<index_t> a_strides_;
std::vector<index_t> b_strides_;
std::vector<index_t> c_strides_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
ADataType,
BDataType,
CDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m_,
arg.b_grid_desc_m_,
arg.c_grid_desc_m_,
arg.functor_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.size() != NDim)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
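// Vectorized loads/stores are only legal when the fastest-moving dimension is
// contiguous (stride 1); otherwise the per-tensor ScalarPerVector must be 1.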
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
bool ret = true;
if(!isLastDimensionCoalesced)
ret = scalarPerVector == 1;
else
ret = MPerThread % scalarPerVector == 0;
return ret;
};
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
return false;
return true;
}
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, 2> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<CDataType*>(p_outputs[0]),
lengths,
input_strides[0],
input_strides[1],
output_strides[0],
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
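// --- Illustrative usage sketch (editor-added, not part of this header) ----------
// A minimal host-side example of instantiating and running DeviceBinaryElementwise.
// The float data type, the Add functor (from the binary element-wise operation
// header), the 2-D row-major problem, and the tuning parameters
// (MPerThread / ScalarPerVector) are assumptions chosen only to illustrate the
// MakeArgumentPointer / MakeInvokerPointer flow above; p_a, p_b, and p_c are
// assumed to be valid device allocations of matching size.
inline float RunBinaryAddExample(const float* p_a, const float* p_b, float* p_c)
{
    using Add = ck::tensor_operation::element_wise::Add;

    using DeviceOp = ck::tensor_operation::device::DeviceBinaryElementwise<float,
                                                                           float,
                                                                           float,
                                                                           float,
                                                                           Add,
                                                                           2,  // NDim
                                                                           8,  // MPerThread
                                                                           4,  // AScalarPerVector
                                                                           4,  // BScalarPerVector
                                                                           4>; // CScalarPerVector

    auto op = DeviceOp{};

    // 1024 x 1024 row-major tensors: one set of lengths with per-tensor strides.
    auto arg = op.MakeArgumentPointer({p_a, p_b},
                                      {p_c},
                                      {1024, 1024},           // lengths
                                      {{1024, 1}, {1024, 1}}, // input strides
                                      {{1024, 1}},            // output strides
                                      Add{});

    if(!op.IsSupportedArgument(arg.get()))
        return -1.f;

    auto invoker = op.MakeInvokerPointer();
    return invoker->Run(arg.get(), StreamConfig{});
}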
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_cgemm.hpp" #include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -538,48 +538,43 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -538,48 +538,43 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
using Add = ck::tensor_operation::element_wise::Add; using Add = ck::tensor_operation::element_wise::Add;
using Subtract = ck::tensor_operation::element_wise::Subtract; using Subtract = ck::tensor_operation::element_wise::Subtract;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
CDataType, using GridwiseBinAdd =
CDataType, GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
CDataType, Tuple<CGridDesc_M>,
CGridDesc_M, Tuple<const CDataType*, const CDataType*>,
CGridDesc_M, Tuple<CDataType*>,
CGridDesc_M, Add,
Add, MPerThread,
MPerThread, Sequence<AScalarPerVector, BScalarPerVector>,
AScalarPerVector, Sequence<CScalarPerVector>>;
BScalarPerVector,
CScalarPerVector>; using GridwiseBinSubtract =
using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType, GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
CDataType, Tuple<CGridDesc_M>,
CDataType, Tuple<const CDataType*, const CDataType*>,
CDataType, Tuple<CDataType*>,
CGridDesc_M, Subtract,
CGridDesc_M, MPerThread,
CGridDesc_M, Sequence<AScalarPerVector, BScalarPerVector>,
Subtract, Sequence<CScalarPerVector>>;
MPerThread,
AScalarPerVector, const auto add_kernel = kernel_elementwise_1d<GridwiseBinAdd,
BScalarPerVector, Tuple<CGridDesc_M, CGridDesc_M>,
CScalarPerVector>; Tuple<CGridDesc_M>,
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd, Tuple<const CDataType*, const CDataType*>,
CDataType, Tuple<CDataType*>,
CDataType, Add>;
CDataType,
CGridDesc_M, const auto subtract_kernel =
CGridDesc_M, kernel_elementwise_1d<GridwiseBinSubtract,
CGridDesc_M, Tuple<CGridDesc_M, CGridDesc_M>,
Add>; Tuple<CGridDesc_M>,
const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract, Tuple<const CDataType*, const CDataType*>,
CDataType, Tuple<CDataType*>,
CDataType, Subtract>;
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Subtract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
...@@ -631,18 +626,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -631,18 +626,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(
subtract_kernel, stream_config,
dim3(grid_size), subtract_kernel,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_aux_grid_, 0,
arg.p_aux_2_grid_, make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
arg.p_c_grid_real_, make_tuple(arg.c_grid_desc_m_),
arg.c_grid_desc_m_, make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
arg.c_grid_desc_m_, const_cast<const CDataType*>(arg.p_aux_2_grid_)),
arg.c_grid_desc_m_, make_tuple(arg.p_c_grid_real_),
Subtract{}); Subtract{});
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -679,18 +674,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -679,18 +674,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_imag = aux + aux_2 // c_imag = aux + aux_2
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(
add_kernel, stream_config,
dim3(grid_size), add_kernel,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_aux_grid_, 0,
arg.p_aux_2_grid_, make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
arg.p_c_grid_imag_, make_tuple(arg.c_grid_desc_m_),
arg.c_grid_desc_m_, make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
arg.c_grid_desc_m_, const_cast<const CDataType*>(arg.p_aux_2_grid_)),
arg.c_grid_desc_m_, make_tuple(arg.p_c_grid_imag_),
Add{}); Add{});
} }
else else
{ {
...@@ -742,18 +737,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -742,18 +737,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(
subtract_kernel, stream_config,
dim3(grid_size), subtract_kernel,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_aux_grid_, 0,
arg.p_aux_2_grid_, make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
arg.p_c_grid_real_, make_tuple(arg.c_grid_desc_m_),
arg.c_grid_desc_m_, make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
arg.c_grid_desc_m_, const_cast<const CDataType*>(arg.p_aux_2_grid_)),
arg.c_grid_desc_m_, make_tuple(arg.p_c_grid_real_),
Subtract{}); Subtract{});
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -790,18 +785,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -790,18 +785,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_imag = aux + aux_2 // c_imag = aux + aux_2
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(
add_kernel, stream_config,
dim3(grid_size), add_kernel,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_aux_grid_, 0,
arg.p_aux_2_grid_, make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
arg.p_c_grid_imag_, make_tuple(arg.c_grid_desc_m_),
arg.c_grid_desc_m_, make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
arg.c_grid_desc_m_, const_cast<const CDataType*>(arg.p_aux_2_grid_)),
arg.c_grid_desc_m_, make_tuple(arg.p_c_grid_imag_),
Add{}); Add{});
} }
return ave_time; return ave_time;
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
......