Commit b89a88b5 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge branch 'develop' into wavelet_model

parents 41d5fca7 43c898f6
...@@ -2,38 +2,286 @@ ...@@ -2,38 +2,286 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector> #include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "device_base.hpp" #include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <ck::index_t NumInputTensor, template <typename InDataTypeTuple,
ck::index_t NumOutputTensor, typename OutDataTypeTuple,
index_t NDim, typename ElementwiseOperation,
typename ElementwiseFunctor> index_t NumDim,
struct DeviceElementwise : public BaseOperator index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise
: public DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{ {
virtual std::unique_ptr<BaseArgument> static constexpr int NumInput = InDataTypeTuple::Size();
MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs, static constexpr int NumOutput = OutDataTypeTuple::Size();
std::array<void*, NumOutputTensor> p_outputs,
std::vector<index_t> lengths, static_assert(NumInput == InScalarPerVectorSeq::Size() &&
std::vector<std::vector<index_t>> input_strides, NumOutput == OutScalarPerVectorSeq::Size(),
std::vector<std::vector<index_t>> output_strides, "Tuple size is inconsistent with the number of in/out!");
ElementwiseFunctor functor) = 0;
static auto GenerateInDataTypePointerTuple()
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; {
}; return generate_tuple(
[&](auto I) {
template <ck::index_t NumInputTensor, using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
ck::index_t NumOutputTensor,
index_t NDim, return static_cast<const DataType*>(nullptr);
typename ElementwiseFunctor> },
using DeviceElementwisePtr = Number<NumInput>{});
std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>; };
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
constexpr auto I0 = Number<0>{};
const auto m = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
}
static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(NumDim > 1)
{
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
template <index_t TupleSize>
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 1)
{
return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
}
else
{
return MakeDescriptor_M({1}, {1}, 1, 1);
};
},
Number<TupleSize>{});
};
using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{}));
using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{}));
using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
OutGrid1dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
in_grid_1d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_M(
lengths, inStridesArray[I.value], gridSize_, blockSize_);
},
Number<NumInput>{});
out_grid_1d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_M(
lengths, outStridesArray[I.value], gridSize_, blockSize_);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
InGrid1dDescTuple in_grid_1d_desc_tuple_;
OutGrid1dDescTuple out_grid_1d_desc_tuple_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
InGrid1dDescTuple,
OutGrid1dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.in_grid_1d_desc_tuple_,
arg.out_grid_1d_desc_tuple_,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) {
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
return true;
if(strides.back() != 1 && scalarPerVector == 1)
return true;
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false;
});
return valid;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
}; // namespace device
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <array>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim>
struct DeviceElementwiseBase : public BaseOperator
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; // namespace device
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim>
using DeviceElementwiseBasePtr = std::unique_ptr<
DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -3,9 +3,6 @@ ...@@ -3,9 +3,6 @@
#pragma once #pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck { namespace ck {
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <array> #include <array>
#include "device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -3,15 +3,27 @@ ...@@ -3,15 +3,27 @@
#pragma once #pragma once
#include <iostream> #include <array>
#include "device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// FIXME: DeviceGemmReduce type need to well define the problem // FIXME: DeviceGemmReduce type need to well define the problem
// GEMM:
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : R0[M], R1[M], ...
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DELayout, typename DELayout,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -192,7 +193,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -192,7 +193,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...@@ -207,95 +211,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -207,95 +211,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
} }
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
} }
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -310,92 +229,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -310,92 +229,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
} }
}(); }();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
} }
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
...@@ -413,47 +247,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -413,47 +247,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
} }
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return e_grid_desc_mraw_nraw;
}
} }
// assume D is packed tensor // assume D is packed tensor
...@@ -482,10 +276,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -482,10 +276,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
} }
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
using RGridDesc_M = decltype(MakeRGridDescriptor_M(1)); using RGridDesc_M = decltype(MakeRGridDescriptor_M(1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
...@@ -504,8 +298,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -504,8 +298,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
ThreadReduceOperations, ThreadReduceOperations,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
RsGlobalMemoryDataOperation, RsGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_M_K,
BGridDesc_BK0_N_BK1, BGridDesc_N_K,
EGridDesc_M_N, EGridDesc_M_N,
RGridDesc_M, RGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
...@@ -542,6 +336,13 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -542,6 +336,13 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
RThreadTransferDstScalarPerVector_MPerBlock, RThreadTransferDstScalarPerVector_MPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -567,12 +368,16 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -567,12 +368,16 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
p_ds_grid_{}, // FIXME p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_rs_grid_{}, // FIXME p_rs_grid_{}, // FIXME
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)}, r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
rs_grid_desc_mblock_mperblock_{}, rs_grid_desc_mblock_mperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -581,8 +386,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -581,8 +386,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
qs_element_op_{qs_element_op}, qs_element_op_{qs_element_op},
rs_element_op_{rs_element_op} rs_element_op_{rs_element_op}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_n_k_,
e_grid_desc_m_n_, e_grid_desc_m_n_,
r_grid_desc_m_, r_grid_desc_m_,
block_2_etile_map_)) block_2_etile_map_))
...@@ -624,6 +429,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -624,6 +429,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
typename GridwiseGemm::RsGridPointer p_rs_grid_; typename GridwiseGemm::RsGridPointer p_rs_grid_;
// tensor descriptors // tensor descriptors
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
EGridDesc_M_N e_grid_desc_m_n_;
RGridDesc_M r_grid_desc_m_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< StaticallyIndexedArray<
...@@ -631,16 +442,14 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -631,16 +442,14 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
NumDTensor> NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E // type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
RGridDesc_M r_grid_desc_m_;
StaticallyIndexedArray<typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, NumRTensor> StaticallyIndexedArray<typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, NumRTensor>
rs_grid_desc_mblock_mperblock_; rs_grid_desc_mblock_mperblock_;
// block-to-e-tile map // block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -657,8 +466,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -657,8 +466,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.r_grid_desc_m_, arg.r_grid_desc_m_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
...@@ -750,8 +559,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -750,8 +559,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
return false; return false;
} }
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.r_grid_desc_m_, arg.r_grid_desc_m_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
......
...@@ -292,8 +292,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -292,8 +292,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -391,7 +389,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -391,7 +389,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceGemmXdlSplitK : public DeviceGemmSplitK<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto K1Number = Number<K1>{};
static auto
MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
{
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
static auto
MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
{
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
// GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t k_batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
k_batch_{k_batch}
{
int KPad = DeviceGemmXdlSplitK::GetKPad(K, k_batch_);
a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmXdlSplitK::MakeAGridDescriptor_KBatch_K0_M_K1(
M, K, StrideA, k_batch_, KPad);
b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmXdlSplitK::MakeBGridDescriptor_KBatch_K0_N_K1(
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t k_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmXdlSplitK::Argument;
void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
// FIXME: this should be moved outside of DeviceOp
hipGetErrorString(
hipMemset(arg.p_c_grid_,
0,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
sizeof(CDataType)));
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
true>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
true>;
Run(kernel);
}
}
else
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
false>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
false>;
Run(kernel);
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op,
KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op,
KBatch);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdlSplitK"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -95,7 +95,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -95,7 +95,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto a_grid_desc_m_kpad = transform_tensor_descriptor( const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)), make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#pragma once #pragma once
#include <vector> #include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Grouped Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// output : R0[G, N, Ho, Wo], R1[G, N, Ho, Wo], ...
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DELayout,
typename RLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename RsDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename QsElementwiseOperation,
typename RsElementwiseOperation>
struct DeviceGroupedConvFwdMultipleDMultipleR : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumRTensor = RsDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_multiple_r.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
template <index_t NumDTensor, index_t NumRTensor>
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE,
Array<ck::index_t, NumRTensor> BatchStrideRs)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE),
BatchStrideRs_(BatchStrideRs)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
__host__ __device__ constexpr auto GetRsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumRTensor> rs_offset;
static_for<0, NumRTensor, 1>{}(
[&](auto i) { rs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideRs_[i]); });
return rs_offset;
}
index_t BatchStrideA_;
index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
Array<ck::index_t, NumRTensor> BatchStrideRs_;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename RsPointer,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename QsElementwiseOperation,
typename RsElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename RsGridDescriptor_MBlock_MPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batch_gemm_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
RsPointer p_rs_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const QsElementwiseOperation qs_element_op,
const RsElementwiseOperation rs_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
RsPointer p_rs_grid_grp;
static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size();
static_for<0, NumRTensor, 1>{}(
[&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_rs_grid_grp,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
rs_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = p_rs_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = rs_grid_desc_mblock_mperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = qs_element_op;
ignore = rs_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
} // namespace
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DELayout,
typename RLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ReduceAccDataType,
typename RsDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename QsElementwiseOperation,
typename RsElementwiseOperation,
typename ThreadReduceOperations,
typename RsGlobalMemoryDataOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
index_t RThreadTransferDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleDMultipleR<NDimSpatial,
ALayout,
BLayout,
DELayout,
RLayout,
ADataType,
BDataType,
DsDataType,
EDataType,
RsDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
RsElementwiseOperation,
QsElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumRTensor = RsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
template <typename Descriptor>
static auto GetPaddedRGridDescriptor(Descriptor descriptor, index_t MRaw)
{
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(
descriptor,
make_tuple(make_right_pad_transform(descriptor, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return descriptor;
}
}
template <typename RLay,
typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::GNW> ||
is_same_v<RLay, tensor_layout::convolution::GNHW> ||
is_same_v<RLay, tensor_layout::convolution::GNDHW>,
bool>::type = false>
static auto
MakeRGridDescriptor_M(const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& /* r_g_n_wos_strides */)
{
const index_t N = r_g_n_wos_lengths[1];
const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2,
r_g_n_wos_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo));
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
template <typename RLay,
typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::G_NW> ||
is_same_v<RLay, tensor_layout::convolution::G_NHW> ||
is_same_v<RLay, tensor_layout::convolution::G_NDHW> ||
is_same_v<RLay, tensor_layout::convolution::NWG> ||
is_same_v<RLay, tensor_layout::convolution::NHWG> ||
is_same_v<RLay, tensor_layout::convolution::NDHWG>,
bool>::type = false>
static auto MakeRGridDescriptor_M(const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides)
{
const index_t N = r_g_n_wos_lengths[1];
const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2];
const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2,
r_g_n_wos_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto r_grid_desc_mraw =
make_naive_tensor_descriptor(make_tuple(NHoWo), make_tuple(WoStride));
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
using AGridDesc_M_K = remove_cvref_t<decltype(
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>;
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ReduceAccDataType,
RsDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
QsElementwiseOperation,
RsElementwiseOperation,
ThreadReduceOperations,
InMemoryDataOperationEnum::Set,
RsGlobalMemoryDataOperation,
AGridDesc_M_K,
BGridDesc_N_K,
EGridDesc_M_N,
RGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
RThreadTransferDstScalarPerVector_MPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
p_rs_grid_{}, // FIXME
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<DELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
r_grid_desc_m_{
DeviceOp::MakeRGridDescriptor_M<RLayout>(r_g_n_wos_lengths, r_g_n_wos_strides)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
rs_grid_desc_mblock_mperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
qs_element_op_{qs_element_op},
rs_element_op_{rs_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
e_grid_desc_m_n_,
r_grid_desc_m_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DELayout>(
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_(i));
});
// populate pointer for Rs
static_for<0, NumRTensor, 1>{}([&](auto i) {
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
// R pointer
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
rs_grid_desc_mblock_mperblock_(i) =
GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_);
});
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
EGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
RGridDesc_M r_grid_desc_m_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
StaticallyIndexedArray<typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, NumRTensor>
rs_grid_desc_mblock_mperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
ComputePtrOffsetOfStridedBatch<NumDTensor, NumRTensor> compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
QsElementwiseOperation qs_element_op_;
RsElementwiseOperation rs_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
arg.r_grid_desc_m_,
arg.block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) *
arg.a_g_n_c_wis_lengths_[0]; // Group count
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
typename GridwiseGemm::RsGridPointer,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
QsElementwiseOperation,
RsElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ck::StaticallyIndexedArray<
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
NumRTensor>,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor, NumRTensor>,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.p_rs_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.qs_element_op_,
arg.rs_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.rs_grid_desc_mblock_mperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx908")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else if(get_device_name() == "gfx90a")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
{
return false;
}
}
else
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of Ds
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
// FIXME: layout
if constexpr(is_same_v<DELayout, ctc::G_NW_K> || is_same_v<DELayout, ctc::G_NHW_K> ||
is_same_v<DELayout, ctc::G_NDHW_K> || is_same_v<DELayout, ctc::GNWK> ||
is_same_v<DELayout, ctc::GNHWK> || is_same_v<DELayout, ctc::GNDHWK> ||
is_same_v<DELayout, ctc::NWGK> || is_same_v<DELayout, ctc::NHWGK> ||
is_same_v<DELayout, ctc::NDHWGK>)
{
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
valid = false;
}
}
else
{
valid = false;
}
});
if(!valid)
{
return false;
}
// check vector access of E
if constexpr(is_same_v<DELayout, ctc::G_NW_K> || is_same_v<DELayout, ctc::G_NHW_K> ||
is_same_v<DELayout, ctc::G_NDHW_K> || is_same_v<DELayout, ctc::GNWK> ||
is_same_v<DELayout, ctc::GNHWK> || is_same_v<DELayout, ctc::GNDHWK> ||
is_same_v<DELayout, ctc::NWGK> || is_same_v<DELayout, ctc::NHWGK> ||
is_same_v<DELayout, ctc::NDHWGK>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of R
if constexpr(!(is_same_v<RLayout, ctc::G_NW> || is_same_v<RLayout, ctc::G_NHW> ||
is_same_v<RLayout, ctc::G_NDHW> || is_same_v<RLayout, ctc::GNW> ||
is_same_v<RLayout, ctc::GNHW> || is_same_v<RLayout, ctc::GNDHW> ||
is_same_v<RLayout, ctc::NWG> || is_same_v<RLayout, ctc::NHWG> ||
is_same_v<RLayout, ctc::NDHWG>))
{
return false;
}
// check Gridwise GEMM
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
arg.r_grid_desc_m_,
arg.block_2_etile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
p_rs,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
r_g_n_wos_lengths,
r_g_n_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
p_rs,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
r_g_n_wos_lengths,
r_g_n_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -13,8 +13,9 @@ ...@@ -13,8 +13,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
...@@ -296,922 +297,71 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -296,922 +297,71 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay, template <typename ALay>
typename std::enable_if<NDimSpatial == 1 &&
is_same_v<ALay, tensor_layout::convolution::GNWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = e_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALay, tensor_layout::convolution::GNHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = e_g_n_k_wos_lengths[3];
const index_t Wo = e_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 3 &&
is_same_v<ALay, tensor_layout::convolution::GNDHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = e_g_n_k_wos_lengths[3];
const index_t Ho = e_g_n_k_wos_lengths[4];
const index_t Wo = e_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties
template <typename ALay,
typename std::enable_if<NDimSpatial == 1 &&
(is_same_v<ALay, tensor_layout::convolution::G_NW_C> ||
is_same_v<ALay, tensor_layout::convolution::NWGC>),
bool>::type = false>
static auto static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const index_t N = a_g_n_c_wis_lengths[1]; const auto in_gemmmraw_gemmkraw_desc =
const index_t C = a_g_n_c_wis_lengths[2]; conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
const index_t Wi = a_g_n_c_wis_lengths[3]; b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
const index_t Wo = e_g_n_k_wos_lengths[3]; e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
const index_t ConvStrideW = conv_filter_strides[0]; conv_filter_strides,
conv_filter_dilations,
if constexpr(ConvForwardSpecialization == input_left_pads,
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) input_right_pads);
{
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, const auto in_gemmm_gemmk_desc =
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
index_t{1},
std::multiplies<index_t>()); return in_gemmm_gemmk_desc;
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<ALay, tensor_layout::convolution::G_NHW_C> ||
is_same_v<ALay, tensor_layout::convolution::NHWGC>),
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = e_g_n_k_wos_lengths[3];
const index_t Wo = e_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 3 &&
(is_same_v<ALay, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALay, tensor_layout::convolution::NDHWGC>),
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = e_g_n_k_wos_lengths[3];
const index_t Ho = e_g_n_k_wos_lengths[4];
const index_t Wo = e_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename BLay,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::GKXC> ||
is_same_v<BLay, tensor_layout::convolution::GKYXC> ||
is_same_v<BLay, tensor_layout::convolution::GKZYXC>,
bool>::type = false>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto wei_k_yxc_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
const auto wei_gemmn_gemmk_grid_desc =
matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
return wei_gemmn_gemmk_grid_desc;
} }
template <typename BLay, template <typename BLay>
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::G_K_X_C> ||
is_same_v<BLay, tensor_layout::convolution::G_K_YX_C> ||
is_same_v<BLay, tensor_layout::convolution::G_K_ZYX_C> ||
is_same_v<BLay, tensor_layout::convolution::KXGC> ||
is_same_v<BLay, tensor_layout::convolution::KYXGC> ||
is_same_v<BLay, tensor_layout::convolution::KZYXGC>,
bool>::type = false>
static auto static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides) const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const index_t K = b_g_k_c_xs_lengths[1]; const auto wei_gemmnraw_gemmkraw_desc =
const index_t C = b_g_k_c_xs_lengths[2]; conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto wei_k_yx_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
const auto wei_gemmnraw_gemmkraw_grid_desc = transform_tensor_descriptor( const auto wei_gemmn_gemmk_desc =
wei_k_yx_c_grid_desc, matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmn_gemmk_grid_desc = return wei_gemmn_gemmk_desc;
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_grid_desc);
return wei_gemmn_gemmk_grid_desc;
} }
template <typename ELay, template <typename ELay>
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::GNWK> ||
is_same_v<ELay, tensor_layout::convolution::GNHWK> ||
is_same_v<ELay, tensor_layout::convolution::GNDHWK>,
bool>::type = false>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */)
{
const index_t N = e_g_n_k_wos_lengths[1];
const index_t K = e_g_n_k_wos_lengths[2];
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmmraw_gemmnraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
return out_gemmm_gemmn_grid_desc;
}
template <typename ELay,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::G_NW_K> ||
is_same_v<ELay, tensor_layout::convolution::G_NHW_K> ||
is_same_v<ELay, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<ELay, tensor_layout::convolution::NWGK> ||
is_same_v<ELay, tensor_layout::convolution::NHWGK> ||
is_same_v<ELay, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
static auto static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const index_t N = e_g_n_k_wos_lengths[1]; const auto out_gemmmraw_gemmnraw_desc =
const index_t K = e_g_n_k_wos_lengths[2]; conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
const auto KStride = I1;
const index_t WoStride = e_g_n_k_wos_strides[NDimSpatial + 2];
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmmraw_gemmnraw_grid_desc = const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
const auto out_gemmm_gemmn_grid_desc = return out_gemmm_gemmn_desc;
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
return out_gemmm_gemmn_grid_desc;
} }
static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(
...@@ -1456,9 +606,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -1456,9 +606,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 1 if(stream_config.log_level_ > 0)
arg.Print(); {
#endif arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include <array>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank,
index_t NumReduceDim,
index_t NumReduction,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple>
struct DeviceMultipleReduce : public BaseOperator
{
static constexpr index_t NumInputDim = Rank;
static constexpr index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, NumInputDim> inLengths,
const std::array<index_t, NumInputDim> inStrides,
const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas,
const std::array<const void*, NumReduction> betas,
const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
const AccElementwiseOperationTuple acc_elementwise_op_tuple) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank,
index_t NumReduceDim,
index_t NumReduction,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple>
using DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank,
NumReduceDim,
NumReduction,
InElementwiseOperationTuple,
AccElementwiseOperationTuple>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/sequence.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumReduction,
typename InDataType,
typename AccDataType,
typename OutDataTypeTuple,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
typename OutDstVectorSizeSeq>
struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
NumReduceDim,
NumReduction,
InElementwiseOperationTuple,
AccElementwiseOperationTuple>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(NumReduction == OutDataTypeTuple::Size() &&
NumReduction == InElementwiseOperationTuple::Size() &&
NumReduction == AccElementwiseOperationTuple::Size() &&
NumReduction == OutDstVectorSizeSeq::Size(),
"All tuple should have the same size as the number of Reductions!");
static_assert(sequence_all_of(OutDstVectorSizeSeq{},
[](auto vectorSize) {
return (MThreadSliceSize % vectorSize == 0);
}),
"The OutDstVectorSize should completely divide the MThreadSliceSize!");
static constexpr bool CheckDataTypeTuple()
{
bool flag = true;
static_for<0, NumReduction, 1>{}([&](auto I) {
using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
flag =
flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value;
});
return flag;
};
static_assert(CheckDataTypeTuple(),
"The OutDataType must support the specified OutMemoryDataOperation!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumInputDim = Rank;
static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
// later
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static_assert(
ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
"The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumReduction>{});
};
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
const std::array<index_t, NumInputDim>& inStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleSrcLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
const std::array<index_t, NumOutputDim>& outStrides)
{
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
static auto GenerateOutGrid1dDescTuple()
{
return generate_tuple(
[&](auto I) {
(void)I;
return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
std::array<index_t, NumOutputDim>{});
},
Number<NumReduction>{});
};
using InGridDesc_M_K = decltype(MakeSrc2dDescriptor(
std::array<index_t, NumInputDim>{}, std::array<index_t, NumInputDim>{}, 1, 1));
using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple());
static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumOutputDim>& outLengths,
const std::array<index_t, NumOutputDim>& outStrides)
{
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{});
const auto pad = math::integer_least_multiple(length, BlockSize) - length;
auto out_grid_desc_m_padded =
transform_tensor_descriptor(out_grid_desc_m,
make_tuple(make_right_pad_transform(length, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
static auto GenerateOutGrid1dDescTuple_2()
{
return generate_tuple(
[&](auto I) {
(void)I;
return MakeDst1dDescriptorForBufferSet(std::array<index_t, NumOutputDim>{},
std::array<index_t, NumOutputDim>{});
},
Number<NumReduction>{});
};
using OutGridDesc_M_Tuple_2 = decltype(GenerateOutGrid1dDescTuple_2());
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumInputDim>& inLengths,
const std::array<index_t, NumInputDim>& inStrides,
const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas,
const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
const AccElementwiseOperationTuple acc_elementwise_op_tuple)
: outLengths_{outLengths},
outStridesArray_{outStridesArray},
in_elementwise_op_tuple_{in_elementwise_op_tuple},
acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
for(size_t i = 0; i < NumReduction; i++)
{
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
};
in_dev_ = static_cast<const InDataType*>(in_dev);
out_dev_buffers_ = generate_tuple(
[&](auto iR) {
using OutDataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
using OutDataType = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;
return static_cast<OutDataType*>(out_dev_buffers[iR]);
},
Number<NumReduction>{});
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(use_multiblock)
{
int iterations = 1;
while(true)
{
int testBlkGroupSize =
(reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128
if(testBlkGroupSize <= 128)
break;
iterations++;
};
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration = iterations;
}
else
{
blkGroupSize = 1;
numBlockTileIteration =
(reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
};
in_grid_desc_m_k =
MakeSrc2dDescriptor(inLengths_, inStrides_, blkGroupSize, numBlockTileIteration);
out_grid_desc_m_tuple = generate_tuple(
[&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
Number<NumReduction>{});
out_grid_desc_m_tuple_2 = generate_tuple(
[&](auto I) {
return MakeDst1dDescriptorForBufferSet(outLengths, outStridesArray[I]);
},
Number<NumReduction>{});
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
gridSize_pre =
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
}
std::array<index_t, NumInputDim> inLengths_;
std::array<index_t, NumInputDim> inStrides_;
std::array<index_t, NumOutputDim> outLengths_;
std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;
Array<AccDataType, NumReduction> alpha_values_;
Array<AccDataType, NumReduction> beta_values_;
const InDataType* in_dev_;
OutDataTypePointerTuple out_dev_buffers_;
InGridDesc_M_K in_grid_desc_m_k;
OutGridDesc_M_Tuple out_grid_desc_m_tuple;
OutGridDesc_M_Tuple_2 out_grid_desc_m_tuple_2;
InElementwiseOperationTuple in_elementwise_op_tuple_;
AccElementwiseOperationTuple acc_elementwise_op_tuple_;
long_index_t invariant_total_length;
long_index_t reduce_total_length;
int blkGroupSize;
int numBlockTileIteration;
size_t gridSize;
size_t gridSize_pre;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using GridwiseMultipleReduce =
GridwiseMultipleReduction_mk_to_m_multiblock<NumReduction,
InDataType,
OutDataTypePointerTuple,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M_Tuple,
ReduceOperation,
InElementwiseOperationTuple,
AccElementwiseOperationTuple,
OutMemoryDataOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSizeSeq>;
const auto kernel_main =
kernel_multiple_reduce_multiblock<GridwiseMultipleReduce,
NumReduction,
InDataType,
OutDataTypePointerTuple,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M_Tuple,
InElementwiseOperationTuple,
AccElementwiseOperationTuple>;
float avg_time = 0;
if constexpr(use_multiblock)
{
auto identity_values = generate_tuple(
[&](auto iR) {
using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[iR])>;
return ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation);
},
Number<NumReduction>{});
const auto kernel_pre = kernel_multiple_buffer_set_value<OutGridDesc_M_Tuple_2,
NumReduction,
BlockSize,
OutDataTypePointerTuple,
OutDataTypeTuple>;
avg_time += launch_and_time_kernel(stream_config,
kernel_pre,
dim3(arg.gridSize_pre),
dim3(BlockSize),
0,
arg.out_grid_desc_m_tuple_2,
arg.out_dev_buffers_,
identity_values);
};
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
arg.in_grid_desc_m_k,
arg.out_grid_desc_m_tuple,
arg.in_elementwise_op_tuple_,
arg.acc_elementwise_op_tuple_,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.alpha_values_,
arg.in_dev_,
arg.beta_values_,
arg.out_dev_buffers_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(use_multiblock)
{
for(size_t i = 0; i < pArg->beta_values_.Size(); i++)
if(pArg->beta_values_[i] != 0.0f)
return (false);
};
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
return (false);
if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
return (false);
if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
return (false);
};
// To improve
bool valid = true;
static_for<0, NumReduction, 1>{}([&](auto I) {
if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
OutDstVectorSizeSeq::At(I) != 1)
valid = false;
if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
valid = false;
});
if(!valid)
return (false);
if constexpr(use_multiblock)
{
// blkGroupSize of 1 should be handled by Blockwise path using
// InMemoryDataOperationEnum::Set
if(pArg->blkGroupSize == 1)
return (false);
// This is very strong restriction, but needed to avoid some failure
if(pArg->outLengths_[NumOutputDim - 1] % M_BlockTileSize != 0)
return (false);
}
else
{
// cases with very small reduce_total_length should be handled by ThreadWise kernel
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
};
return (true);
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, NumInputDim> inLengths,
const std::array<index_t, NumInputDim> inStrides,
const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas,
const std::array<const void*, NumReduction> betas,
const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStridesArray,
reduceDims,
alphas,
betas,
in_dev,
out_dev_buffers,
in_elementwise_op_tuple,
acc_elementwise_op_tuple);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceMultipleReduceBlockWise<" : "DeviceMultipleReduceMultiBlock<") << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
str << "OutDstVectorSize";
static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); });
str << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/sequence.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumReduction,
typename InDataType,
typename AccDataType,
typename OutDataTypeTuple,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperationTuple,
typename AccElementwiseOperationTuple,
bool PropagateNan,
index_t BlockSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
typename OutDstVectorSizeSeq>
struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
NumReduceDim,
NumReduction,
InElementwiseOperationTuple,
AccElementwiseOperationTuple>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(NumReduction == OutDataTypeTuple::Size() &&
NumReduction == InElementwiseOperationTuple::Size() &&
NumReduction == AccElementwiseOperationTuple::Size() &&
NumReduction == OutDstVectorSizeSeq::Size(),
"All tuple should have the same size as the number of Reductions!");
static_assert(sequence_all_of(OutDstVectorSizeSeq{},
[](auto vectorSize) {
return (MThreadSliceSize % vectorSize == 0);
}),
"The OutDstVectorSize should completely divide the MThreadSliceSize!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumInputDim = Rank;
static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumReduction>{});
};
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
const std::array<index_t, NumInputDim>& inStrides)
{
const auto tupleSrcLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
const std::array<index_t, NumOutputDim>& outStrides)
{
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
static auto GenerateOutGrid1dDescTuple()
{
return generate_tuple(
[&](auto I) {
(void)I;
return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
std::array<index_t, NumOutputDim>{});
},
Number<NumReduction>{});
};
using InGridDesc_M_K = decltype(MakeSrc2dDescriptor(std::array<index_t, NumInputDim>{},
std::array<index_t, NumInputDim>{}));
using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple());
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumInputDim>& inLengths,
const std::array<index_t, NumInputDim>& inStrides,
const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas,
const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
const AccElementwiseOperationTuple acc_elementwise_op_tuple)
: outLengths_{outLengths},
outStridesArray_{outStridesArray},
in_elementwise_op_tuple_{in_elementwise_op_tuple},
acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
for(size_t i = 0; i < NumReduction; i++)
{
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
};
in_dev_ = static_cast<const InDataType*>(in_dev);
out_dev_buffers_ = generate_tuple(
[&](auto iR) {
using OutDataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
using OutDataType = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;
return static_cast<OutDataType*>(out_dev_buffers[iR]);
},
Number<NumReduction>{});
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
in_grid_desc_m_k = MakeSrc2dDescriptor(inLengths_, inStrides_);
out_grid_desc_m_tuple = generate_tuple(
[&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
Number<NumReduction>{});
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
}
std::array<index_t, NumInputDim> inLengths_;
std::array<index_t, NumInputDim> inStrides_;
std::array<index_t, NumOutputDim> outLengths_;
std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;
Array<AccDataType, NumReduction> alpha_values_;
Array<AccDataType, NumReduction> beta_values_;
const InDataType* in_dev_;
OutDataTypePointerTuple out_dev_buffers_;
InGridDesc_M_K in_grid_desc_m_k;
OutGridDesc_M_Tuple out_grid_desc_m_tuple;
InElementwiseOperationTuple in_elementwise_op_tuple_;
AccElementwiseOperationTuple acc_elementwise_op_tuple_;
long_index_t invariant_total_length;
long_index_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using GridwiseMultipleReduce =
GridwiseMultipleReduction_mk_to_m_threadwise<NumReduction,
InDataType,
OutDataTypePointerTuple,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M_Tuple,
ReduceOperation,
InElementwiseOperationTuple,
AccElementwiseOperationTuple,
InMemoryDataOperationEnum::Set,
PropagateNan,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSizeSeq>;
const auto kernel_main =
kernel_multiple_reduce_threadwise<GridwiseMultipleReduce,
NumReduction,
InDataType,
OutDataTypePointerTuple,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M_Tuple,
InElementwiseOperationTuple,
AccElementwiseOperationTuple>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
arg.in_grid_desc_m_k,
arg.out_grid_desc_m_tuple,
arg.in_elementwise_op_tuple_,
arg.acc_elementwise_op_tuple_,
arg.alpha_values_,
arg.in_dev_,
arg.beta_values_,
arg.out_dev_buffers_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
return (false);
if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
return (false);
if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
return (false);
};
// To improve
bool valid = true;
static_for<0, NumReduction, 1>{}([&](auto I) {
if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
OutDstVectorSizeSeq::At(I) != 1)
valid = false;
if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
valid = false;
});
if(!valid)
return (false);
return (true);
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, NumInputDim> inLengths,
const std::array<index_t, NumInputDim> inStrides,
const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas,
const std::array<const void*, NumReduction> betas,
const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStridesArray,
reduceDims,
alphas,
betas,
in_dev,
out_dev_buffers,
in_elementwise_op_tuple,
acc_elementwise_op_tuple);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceMultipleReduceThreadwise<" << BlockSize << ",";
str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
str << "OutDstVectorSize";
static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); });
str << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -35,6 +35,25 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& ...@@ -35,6 +35,25 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
return std::make_pair(invariant_total_length, reduce_total_length); return std::make_pair(invariant_total_length, reduce_total_length);
}; };
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
constexpr int NumInvariantDim = Rank - NumReduceDim;
for(int i = NumInvariantDim; i < Rank; i++)
reduce_total_length *= inLengths[i];
for(int i = 0; i < NumInvariantDim; i++)
invariant_total_length *= inLengths[i];
return std::make_pair(invariant_total_length, reduce_total_length);
};
// helper functions using variadic template arguments // helper functions using variadic template arguments
template <index_t... Ns> template <index_t... Ns>
auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>) auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
...@@ -85,6 +104,39 @@ std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origL ...@@ -85,6 +104,39 @@ std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origL
return newLengthsStrides; return newLengthsStrides;
}; };
template <index_t Rank, index_t NumReduceDim>
std::array<index_t, Rank>
shuffle_tensor_dimensions(const std::array<index_t, Rank>& origLengthsStrides,
const std::array<int, NumReduceDim>& reduceDims)
{
std::array<index_t, Rank> newLengthsStrides;
int reduceFlag = 0;
// flag the bits for the reduceDims
for(int i = 0; i < NumReduceDim; i++)
{
reduceFlag |= 1 << reduceDims[i];
};
// collect invariant dimensions
int pos = 0;
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0)
{
newLengthsStrides[pos++] = origLengthsStrides[i];
};
// collect reduce dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) > 0)
{
newLengthsStrides[pos++] = origLengthsStrides[i];
};
return newLengthsStrides;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -3,19 +3,10 @@ ...@@ -3,19 +3,10 @@
#pragma once #pragma once
#include <iostream> #include <memory>
#include <sstream> #include <vector>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -24,227 +15,54 @@ namespace device { ...@@ -24,227 +15,54 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
index_t Rank, typename InElementwiseOp,
index_t NumReduceDim, typename AccElementwiseOp,
index_t BlockSize, index_t Rank>
index_t MThreadClusterSize, struct DeviceSoftmax : public BaseOperator
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceSoftmax : public DeviceNormalization
{ {
static constexpr index_t kRank = Rank; //
static constexpr index_t kNumReduceDim = NumReduceDim; // @brief Makes a pointer to Argument class.
//
virtual index_t GetRank() const override { return kRank; } // @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } // @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
using PassThrough = tensor_operation::element_wise::PassThrough; // value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// Used for freeloading of some handy functions from DeviceReduceMultiBlock // value as type AccDataType
using Reduction = DeviceReduceMultiBlock<InDataType, // @param[in] in_dev Typeless const pointer in device memory storing the input
AccDataType, // tensor
OutDataType, // @param out_dev Typeless pointer in device memory storing the output tensor
Rank, // @param[in] in_elementwise_op The input elementwise operation.
NumReduceDim, // @param[in] acc_elementwise_op The accumulation elementwise operation.
reduce::Add, //
PassThrough, // InElementwiseOperation // @return Unique pointer to the Argument class.
PassThrough, // AccElementwiseOperation //
InMemoryDataOperationEnum::Set, virtual std::unique_ptr<BaseArgument>
false, // PropagateNan MakeArgumentPointer(const std::vector<index_t> inLengths,
false, // OutputIndex const std::vector<index_t> inStrides,
false, // HaveIndexInputIfOutputIndex const std::vector<int> reduceDims,
BlockSize, const void* alpha,
MThreadClusterSize, const void* beta,
KThreadClusterSize, const void* in_dev,
MThreadSliceSize, void* out_dev,
KThreadSliceSize, InElementwiseOp in_elementwise_op,
InSrcVectorDim, AccElementwiseOp acc_elementwise_op) = 0;
InSrcVectorSize,
1>; // OutDstVectorSize virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual index_t GetRank() const = 0;
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); virtual index_t GetNumReduceDim() const = 0;
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
false>;
using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
true>;
struct Argument : public Reduction::Argument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims,
AccDataType alpha,
AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev)
: Reduction::Argument(inLengths,
inStrides,
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
PassThrough{},
PassThrough{}),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
}
AccDataType alpha_;
AccDataType beta_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
bool sweep_once =
in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>
: kernel_softmax<GridwiseSoftmaxGeneric,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m_k,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_))
{
return false;
}
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false;
}
return true;
};
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the softmax normalization operate on
// alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
*static_cast<const AccDataType*>(alpha),
*static_cast<const AccDataType*>(beta),
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
}; };
template <typename InDataType,
typename AccDataType,
typename OutDataType,
typename InElementwiseOp,
typename AccElementwiseOp,
index_t Rank>
using DeviceSoftmaxPtr = std::unique_ptr<
DeviceSoftmax<InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename EmbType,
typename IndexType,
typename GammaDataType,
typename BetaDataType,
typename AccDataType,
typename OutType,
ck::index_t BlockSize,
ck::index_t DimClusterSize,
ck::index_t RowClusterSize,
ck::index_t DimPerBlock,
ck::index_t RowPerBlock,
ck::index_t DimThreadSize,
ck::index_t RowVectorSize>
struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
{
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
{
return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
}
struct Argument : public BaseArgument
{
Argument(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const ck::index_t NumRows,
const ck::index_t EmbeddingDim,
const ck::index_t IndexLength,
const AccDataType epsilon)
: p_out_(p_out),
p_emb_a_(p_emb_a),
p_emb_b_(p_emb_b),
p_emb_c_(p_emb_c),
p_index_a_(p_index_a),
p_index_b_(p_index_b),
p_index_c_(p_index_c),
p_gamma_(p_gamma),
p_beta_(p_beta),
NumRows_(NumRows),
EmbeddingDim_(EmbeddingDim),
IndexLength_(IndexLength),
epsilon_(epsilon)
{
grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
}
OutType* p_out_;
const EmbType* p_emb_a_;
const EmbType* p_emb_b_;
const EmbType* p_emb_c_;
const IndexType* p_index_a_;
const IndexType* p_index_b_;
const IndexType* p_index_c_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
ck::index_t NumRows_;
ck::index_t EmbeddingDim_;
ck::index_t IndexLength_;
AccDataType epsilon_;
size_t grid_size_;
};
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
const void* p_emb_a,
const void* p_emb_b,
const void* p_emb_c,
const void* p_index_a,
const void* p_index_b,
const void* p_index_c,
const void* p_gamma,
const void* p_beta,
ck::index_t NumRows,
ck::index_t EmbeddingDim,
ck::index_t IndexLength,
const AccDataType epsilon)
{
return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
reinterpret_cast<const EmbType*>(p_emb_a),
reinterpret_cast<const EmbType*>(p_emb_b),
reinterpret_cast<const EmbType*>(p_emb_c),
reinterpret_cast<const IndexType*>(p_index_a),
reinterpret_cast<const IndexType*>(p_index_b),
reinterpret_cast<const IndexType*>(p_index_c),
reinterpret_cast<const GammaDataType*>(p_gamma),
reinterpret_cast<const BetaDataType*>(p_beta),
NumRows,
EmbeddingDim,
IndexLength,
epsilon);
}
using GridwiseSparseEmbedding =
GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(MakeOutputDescriptor(1, 1)),
BlockSize,
DimClusterSize,
RowClusterSize,
DimPerBlock,
RowPerBlock,
DimThreadSize,
RowVectorSize>;
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
const auto kernel_main =
kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(out_desc)>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
arg.p_out_,
arg.p_emb_a_,
arg.p_emb_b_,
arg.p_emb_c_,
arg.p_index_a_,
arg.p_index_b_,
arg.p_index_c_,
arg.p_gamma_,
arg.p_beta_,
out_desc,
arg.epsilon_);
return (avg_time);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
static bool IsSupportedArgument(const Argument* p_arg)
{
return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
}
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>();
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" <<
DimClusterSize << "x" << RowClusterSize << "_" <<
DimPerBlock << "x" << RowPerBlock << "_" <<
DimThreadSize << "x" << RowVectorSize;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector>
struct DeviceUnaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
}
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
{
const auto desc_m0 = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
}
else
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<ADataType,
BDataType,
GridDesc_M0,
ElementwiseFunctor,
ScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
BDataType* p_b,
const std::vector<index_t>& shape,
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
shape_(shape),
functor_(functor),
blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future
{
index_t tensor_size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size);
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
}
const ADataType* p_a_;
BDataType* p_b_;
std::vector<int> shape_;
GridDesc_M0 a_grid_desc_m0_;
GridDesc_M0 b_grid_desc_m0_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
ADataType,
BDataType,
GridDesc_M0,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.functor_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->shape_.back() % ScalarPerVector != 0)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
void* p_b,
std::vector<index_t> shape,
std::vector<index_t> stride_a,
std::vector<index_t> stride_b,
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<BDataType*>(p_b),
shape,
stride_a,
stride_b,
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "ScalarPerVector = " << ScalarPerVector
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment