Unverified Commit 9312548e authored by zjing14, committed by GitHub

Merge branch 'develop' into sphinx_doc

parents 09908c9f 19490ac4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F64 = double;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F64;
using BDataType = F64;
using AccDataType = F64;
using CShuffleDataType = F64;
using DsDataType = ck::Tuple<>;
using EDataType = F64;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceOpInstanceKKN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>;
using DeviceOpInstanceKNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>;
using DeviceOpInstanceMKN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>;
using DeviceOpInstanceMNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKN;
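// The four instances above differ only in their vectorized-access configuration; the
// letters in each alias name indicate which index of A and B is assumed contiguous in
// memory (naming interpreted from the tuning parameters): e.g. KKN reads both A and B
// along K (AK1 = BK1 = 2, SrcVectorDim = 2), while MNN reads A along M and B along N
// (AK1 = BK1 = 1, SrcVectorDim = 1 with S<0, 2, 1> access order).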
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
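// In scalar form the reference above computes
//   E[m0, m1, n0, n1] = cde_op( sum_{k0, k1} a_op(A[m0, m1, k0, k1]) * b_op(B[n0, n1, k0, k1]) )
// which, with CDEElementOp = Scale, is simply scale * (A contracted with B over k0 and k1).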
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
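// With these defaults, offset(m0, m1, k0, k1) = 524288*m0 + 4096*m1 + 128*k0 + k1.
// Note the innermost pitch (128) exceeds the K1 extent (64): the default strides
// describe a padded, non-packed layout, illustrating that the contraction only
// requires consistent lengths/strides, not a packed tensor.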
float scale = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 23)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t M0 = std::stoi(argv[4]);
const ck::index_t M1 = std::stoi(argv[5]);
const ck::index_t N0 = std::stoi(argv[6]);
const ck::index_t N1 = std::stoi(argv[7]);
const ck::index_t K0 = std::stoi(argv[8]);
const ck::index_t K1 = std::stoi(argv[9]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};
scale = std::stof(argv[22]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n");
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg22: scale\n");
exit(0);
}
Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{scale};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M =
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
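// For the default problem this yields M = 30 * 128 = 3840, N = 32 * 64 = 2048 and
// K = 32 * 64 = 2048, so flop = 2 * M * N * K ~= 3.22e10 below.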
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
return 0;
}
@@ -20,12 +20,12 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using AccDataType = float;
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using ComputeDataType = float;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
@@ -34,7 +34,7 @@ using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
-AccDataType,
+ComputeDataType,
YDataType,
PassThrough,
Rank,
@@ -121,7 +121,7 @@ int main()
GammaDataType,
BetaDataType,
YDataType,
-AccDataType,
+ComputeDataType,
PassThrough,
Rank,
NumReduceDim>;
......
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceOpInstanceKKNN =
ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
ABSpec,
ABSpec,
DESpec,
256,
128,
256,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8>;
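// Reading the tuning parameters against the template parameter list of
// DeviceBatchedContractionMultipleD_Wmma_CShuffle (shown later in this commit):
// 256 = BlockSize, 128/256 = M/NPerBlock, 8 = K0PerBlock, 8 = K1,
// 16x16 = M/NPerWMMA, and 4/4 = M/NRepeat.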
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimG == 2, NumDimM == NumDimN == 2, NumDimK == 1
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G2_M2_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
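// Scalar form of the reference above:
//   E[g0, g1, m0, m1, n0, n1] = cde_op( sum_{k0} a_op(A[g0, g1, m0, m1, k0]) * b_op(B[g0, g1, n0, n1, k0]) )
// Here cde_op is PassThrough; main() applies the bias-adding Add separately during verification.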
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
ck::index_t G0 = 1;
ck::index_t G1 = 2;
ck::index_t M0 = 4;
ck::index_t M1 = 128;
ck::index_t N0 = 16;
ck::index_t N1 = 256;
ck::index_t K0 = 2048;
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
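// Note the e-permute: logically E is indexed [G0, G1, M0, M1, N0, N1], but the strides
// above lay it out in memory as [G0, G1, M0, N0, M1, N1]. D's zero strides on M0/M1
// broadcast a single [G0, G1, N0, N1] bias block across every (m0, m1).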
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G =
ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
ck::index_t M = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
}
return 0;
}
@@ -16,6 +16,9 @@ if(USE_BITINT_EXTENSION_INT4)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif() # USE_BITINT_EXTENSION_INT4
if(GPU_TARGETS MATCHES "gfx1100")
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
endif()
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
......
@@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc,
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(
-num_dim_spatial, threshold_to_catch_partial_args, argv);
+num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
}
else
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
using BF16 = ck::bhalf_t;
using FP16 = ck::half_t;
using FP32 = float;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using I4 = ck::int4_t;
#endif
using I8 = std::int8_t;
using I32 = std::int32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename InputLay, typename WeightLay, typename OutputLay>
struct CommonLayoutSetting
{
using InputLayout = InputLay;
using WeightLayout = WeightLay;
using OutputLayout = OutputLay;
};
template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector;
namespace ctl = ck::tensor_layout::convolution;
template <>
struct CommonLayoutSettingSelector<1> final
: CommonLayoutSetting<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>
{
};
template <>
struct CommonLayoutSettingSelector<2> final
: CommonLayoutSetting<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>
{
};
template <>
struct CommonLayoutSettingSelector<3> final
: CommonLayoutSetting<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>
{
};
template <ck::index_t NDimSpatial>
using InputLayout = typename CommonLayoutSettingSelector<NDimSpatial>::InputLayout;
template <ck::index_t NDimSpatial>
using WeightLayout = typename CommonLayoutSettingSelector<NDimSpatial>::WeightLayout;
template <ck::index_t NDimSpatial>
using OutputLayout = typename CommonLayoutSettingSelector<NDimSpatial>::OutputLayout;
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
};
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
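// Field-by-field, the default problem above is (order inferred from the
// ck::utils::conv::ConvParam usage in this file): 2 spatial dims, G = 32, N = 2,
// K = 256, C = 192, 3x3 filter, 71x71 input, stride {2, 2}, dilation {1, 1},
// and {1, 1} left/right padding.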
inline void print_help_msg()
{
std::cerr << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline bool parse_cmd_args(int argc,
char* argv[],
ExecutionConfig& config,
ck::utils::conv::ConvParam& conv_param)
{
constexpr int num_execution_config_args =
3; // arguments for do_verification, init_method, time_kernel
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
constexpr int threshold_to_catch_all_args =
threshold_to_catch_partial_args + num_conv_param_leading_args;
if(argc == 1)
{
// use default
}
// catch only ExecutionConfig arguments
else if(argc == threshold_to_catch_partial_args)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
// catch both ExecutionConfig & ConvParam arguments
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
}
else
{
print_help_msg();
return false;
}
return true;
}
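// Illustrative invocations under the thresholds above (binary name assumed), with
// threshold_to_catch_partial_args == 4 and threshold_to_catch_all_args == 9:
//   ./example                -> all defaults
//   ./example 1 1 0          -> ExecutionConfig only (argc == 4)
//   ./example 1 1 0 2 32 2 256 192 3 3 71 71 2 2 1 1 1 1 1 1
//     -> also parses a 2D ConvParam: 5 leading values, then 6 groups of
//        num_dim_spatial values, so argc - 9 is a multiple of 3.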
inline HostTensorDescriptor make_input_descriptor(const ck::utils::conv::ConvParam& conv_param)
{
switch(conv_param.num_dim_spatial_)
{
case 1:
return HostTensorDescriptor(
{conv_param.G_, conv_param.N_, conv_param.C_, conv_param.input_spatial_lengths_[0]},
{
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // n
1, // c
conv_param.G_ * conv_param.C_ // wi
});
case 2:
return HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.C_,
conv_param.input_spatial_lengths_[0],
conv_param.input_spatial_lengths_[1]},
{
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
conv_param.G_ * conv_param.C_, // n
1, // c
conv_param.input_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // hi
conv_param.G_ * conv_param.C_ // wi
});
case 3:
return HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.C_,
conv_param.input_spatial_lengths_[0],
conv_param.input_spatial_lengths_[1],
conv_param.input_spatial_lengths_[2]},
{
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n
1, // c
conv_param.input_spatial_lengths_[1] * conv_param.input_spatial_lengths_[2] *
conv_param.G_ * conv_param.C_, // di
conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // hi
conv_param.G_ * conv_param.C_ // wi
});
}
throw std::runtime_error("unsupported number of spatial dimensions");
}
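// The strides above realize an N[W|HW|DHW]GC memory order (C innermost, then G, then
// the spatial dims, then N) even though the descriptor is indexed as [G, N, C, spatial...];
// e.g. for 1D, offset(g, n, c, wi) = g*C + n*(Wi*G*C) + c + wi*(G*C).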
inline HostTensorDescriptor make_weight_descriptor(const ck::utils::conv::ConvParam& conv_param)
{
switch(conv_param.num_dim_spatial_)
{
case 1:
return HostTensorDescriptor(
{conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]},
{
conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.C_, // k
1, // c
conv_param.C_ // x
});
case 2:
return HostTensorDescriptor(
{conv_param.G_,
conv_param.K_,
conv_param.C_,
conv_param.filter_spatial_lengths_[0],
conv_param.filter_spatial_lengths_[1]},
{
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y
conv_param.C_ // x
});
case 3:
return HostTensorDescriptor(
{conv_param.G_,
conv_param.K_,
conv_param.C_,
conv_param.filter_spatial_lengths_[0],
conv_param.filter_spatial_lengths_[1],
conv_param.filter_spatial_lengths_[2]},
{
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
conv_param.C_, // z
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y
conv_param.C_ // x
});
}
throw std::runtime_error("unsupported number of spatial dimensions");
}
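// The bias descriptor below reuses the output lengths but zeroes the n and spatial
// strides, so one K-length bias vector per group is broadcast across N and every
// output position (the g stride of K_ and k stride of 1 match a packed G_K layout).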
inline HostTensorDescriptor make_bias_descriptor(const ck::utils::conv::ConvParam& conv_param)
{
switch(conv_param.num_dim_spatial_)
{
case 1:
return HostTensorDescriptor(
{conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
{
conv_param.K_, // g
0, // n
1, // k
0 // wo
});
case 2:
return HostTensorDescriptor({conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1]},
{
conv_param.K_, // g
0, // n
1, // k
0, // ho
0 // wo
});
case 3:
return HostTensorDescriptor({conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1],
conv_param.output_spatial_lengths_[2]},
{
conv_param.K_, // g
0, // n
1, // k
0, // do
0, // ho
0 // wo
});
}
throw std::runtime_error("unsupported number of spatial dimensions");
}
inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvParam& conv_param)
{
switch(conv_param.num_dim_spatial_)
{
case 1:
return HostTensorDescriptor(
{conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
{
conv_param.K_, // g
conv_param.output_spatial_lengths_[0] * conv_param.G_ * conv_param.K_, // n
1, // k
conv_param.G_ * conv_param.K_ // wo
});
case 2:
return HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1]},
{
conv_param.K_, // g
conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
conv_param.G_ * conv_param.K_, // n
1, // k
conv_param.output_spatial_lengths_[1] * conv_param.G_ * conv_param.K_, // ho
conv_param.G_ * conv_param.K_ // wo
});
case 3:
return HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1],
conv_param.output_spatial_lengths_[2]},
{
conv_param.K_, // g
conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // n
1, // k
conv_param.output_spatial_lengths_[1] * conv_param.output_spatial_lengths_[2] *
conv_param.G_ * conv_param.K_, // do
conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // ho
conv_param.G_ * conv_param.K_ // wo
});
}
throw std::runtime_error("unsupported number of spatial dimensions");
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
// kernel data types
using InKernelDataType = FP16;
using WeiKernelDataType = FP16;
using AccDataType = FP32;
using CShuffleDataType = FP16;
using BiasKernelDataType = FP16;
using ResidualKernelDataType = FP16;
using OutKernelDataType = FP16;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
template <typename BiasLay, typename ResidualLay>
struct LayoutSetting
{
using BiasLayout = BiasLay;
using ResidualLayout = ResidualLay;
};
template <ck::index_t NDimSpatial>
struct LayoutSettingSelector;
template <>
struct LayoutSettingSelector<1> final : LayoutSetting<ctl::G_K, ctl::G_NW_K>
{
};
template <>
struct LayoutSettingSelector<2> final : LayoutSetting<ctl::G_K, ctl::G_NHW_K>
{
};
template <>
struct LayoutSettingSelector<3> final : LayoutSetting<ctl::G_K, ctl::G_NDHW_K>
{
};
template <ck::index_t NDimSpatial>
using BiasLayout = typename LayoutSettingSelector<NDimSpatial>::BiasLayout;
template <ck::index_t NDimSpatial>
using ResidualLayout = typename LayoutSettingSelector<NDimSpatial>::ResidualLayout;
template <ck::index_t NDimSpatial>
using DeviceConvFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
NDimSpatial,
InputLayout<NDimSpatial>,
WeightLayout<NDimSpatial>,
ck::Tuple<BiasLayout<NDimSpatial>, ResidualLayout<NDimSpatial>>,
OutputLayout<NDimSpatial>,
InKernelDataType,
WeiKernelDataType,
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutKernelDataType,
AccDataType,
CShuffleDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8>;
template <ck::index_t NDimSpatial>
using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InUserDataType,
WeiUserDataType,
CShuffleDataType,
InElementOp,
WeiElementOp,
PassThrough>;
template <ck::index_t NDimSpatial>
bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param)
{
static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial");
const auto in_g_n_c_wis_desc = make_input_descriptor(conv_param);
const auto wei_g_k_c_xs_desc = make_weight_descriptor(conv_param);
const auto bias_g_n_k_wos_desc = make_bias_descriptor(conv_param);
const auto out_g_n_k_wos_desc = make_output_descriptor(conv_param);
Tensor<InUserDataType> in(in_g_n_c_wis_desc);
Tensor<WeiUserDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutUserDataType> bias(bias_g_n_k_wos_desc);
Tensor<OutUserDataType> residual(bias_g_n_k_wos_desc);
Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutKernelDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "bias: " << bias.mDesc << std::endl;
std::cout << "residual: " << residual.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InUserDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiUserDataType>{-5, 5});
bias.GenerateTensorValue(GeneratorTensor_2<OutUserDataType>{-5, 5});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InUserDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiUserDataType>{-0.5, 0.5});
bias.GenerateTensorValue(GeneratorTensor_3<OutUserDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize());
#ifdef BUILD_INT4_EXAMPLE
const Tensor<InKernelDataType> in_converted(in);
const Tensor<WeiKernelDataType> wei_converted(wei);
const Tensor<OutKernelDataType> bias_converted(bias);
const Tensor<OutKernelDataType> residual_converted(residual);
in_device_buf.ToDevice(in_converted.mData.data());
wei_device_buf.ToDevice(wei_converted.mData.data());
bias_device_buf.ToDevice(bias_converted.mData.data());
residual_device_buf.ToDevice(residual_converted.mData.data());
#else
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
residual_device_buf.ToDevice(residual.mData.data());
#endif
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(bias_g_n_k_wos_desc.GetLengths(), d0_g_n_k_wos_lengths);
copy(bias_g_n_k_wos_desc.GetStrides(), d0_g_n_k_wos_strides);
copy(bias_g_n_k_wos_desc.GetLengths(), d1_g_n_k_wos_lengths);
copy(bias_g_n_k_wos_desc.GetStrides(), d1_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do Conv
auto conv = DeviceConvFwdInstance<NDimSpatial>{};
auto invoker = conv.MakeInvoker();
auto argument =
conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{bias_device_buf.GetDeviceBuffer(),
residual_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{
{d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{
{d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InUserDataType, WeiUserDataType, OutUserDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(config.do_verification)
{
Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = HostConvFwdInstance<NDimSpatial>{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
c_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
InElementOp{},
WeiElementOp{},
PassThrough{});
ref_invoker.Run(ref_argument);
// TODO: implement elementwise operation for host
out_host.ForEach([&](auto&, auto idx) {
OutElementOp{}(out_host(idx), c_host(idx), bias(idx), residual(idx));
});
out_device_buf.FromDevice(out_device.mData.data());
#ifdef BUILD_INT4_EXAMPLE
const Tensor<OutUserDataType> out_device_converted(out_device);
return ck::utils::check_err(
out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
#else
return ck::utils::check_err(
out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
#endif
}
return true;
}
bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return false;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
}
return false;
}
@@ -23,11 +23,11 @@
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using AccDataType = float;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using ComputeDataType = float;
struct YElementOp
{
@@ -50,7 +50,7 @@ using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
-AccDataType,
+ComputeDataType,
YDataType,
YElementOp,
Rank,
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
GammaDataType,
BetaDataType,
YDataType,
-AccDataType,
+ComputeDataType,
YElementOp>;
ReferenceInstance ref;
......
@@ -53,7 +53,6 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
@@ -84,8 +83,8 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
if(!device_op.IsSupportedArgument(argument))
{
std::cout << "wrong! this device_op instance does not support this problem" << std::endl;
return true;
std::cout << "wrong! this device_op instance does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
......
@@ -168,13 +168,6 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
-// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue
-#ifdef __gfx908__
-#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1
-#else // __gfx90a__, ...
-#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0
-#endif // __gfx908__
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
......
@@ -20,6 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#if CK_TIME_KERNEL
if(stream_config.time_kernel_)
{
+#if DEBUG_LOG
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
@@ -29,15 +30,15 @@
block_dim.y,
block_dim.z);
-const int nrepeat = 10;
-printf("Warm up 1 time\n");
+#endif
// warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+const int nrepeat = 10;
+#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
+#endif
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
......
@@ -14,9 +14,9 @@ namespace device {
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
-typename AccDataType,
+typename ComputeDataType,
typename YDataType,
-typename AccElementwiseOperation,
+typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalization : public BaseOperator
@@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator
void* p_y,
void* p_savedMean,
void* p_savedInvVar,
-AccElementwiseOperation acc_elementwise_op) = 0;
+YElementwiseOperation y_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
@@ -43,17 +43,17 @@ struct DeviceNormalization
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
-typename AccDataType,
+typename ComputeDataType,
typename YDataType,
-typename AccElementwiseOperation,
+typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
-AccDataType,
+ComputeDataType,
YDataType,
-AccElementwiseOperation,
+YElementwiseOperation,
Rank,
NumReduceDim>>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// NOTE: A TensorSpecialization::Packed tensor is "packed" in the sense that the inner
// dimensions within each dimension group (e.g. [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are
// contiguous and ordered, not in the sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...]
// can be permuted while still being a contiguous, unpadded tensor. In other words, it merely
// degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Detail: a packed tensor satisfies
//   stride_0 = 1
//   stride_i = stride_{i - 1} * extent_{i - 1}
// counting dimensions from the innermost (fastest-varying) one.
// So tensor
// [G0, G1, G2, M, N]
// transposed into tensor
// [G0, G2, G1, M, N]
// with strides
// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores
// some strides from the input tensor extents, so finer-grained dimension information is lost.
// Merging dimensions is essentially a degenerate case of TensorSpecialization::Default with
// NumDimG/M/N/K = 1.
//
// The dimension order might need to be exposed in the interface to fully support
// TensorSpecialization::Packed in the traditional sense of a "packed" tensor.
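//
// Worked example of the packed-stride rule (illustrative numbers, not from
// this file): a packed tensor with extents [G0=2, G1=3, M0=4, N0=5] has
// strides [3*4*5, 4*5, 5, 1] = [60, 20, 5, 1]. A freshly packed layout of the
// permuted extents [G1=3, G0=2, M0=4, N0=5] has strides [2*4*5, 4*5, 5, 1] =
// [40, 20, 5, 1], which again satisfies stride_i = stride_{i-1} * extent_{i-1},
// so it is still a packed tensor.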
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization DESpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
: public DeviceBatchedContractionMultipleD<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 = maximum number of contiguous elements accessed per vector memory operation
static constexpr auto K1Number = Number<K1>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
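    // matrix_padder pads the raw M/N/K extents up to multiples of the block
    // tile (MPerBlock, NPerBlock, K0PerBlock * K1), as selected by GemmSpec,
    // so the grid descriptors below can always be tiled evenly.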
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto a_ms_ks_lengths = to_tuple(
a_gs_ms_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
const auto a_ms_ks_strides = to_tuple(
a_gs_ms_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
if constexpr(ASpec == TensorSpecialization::Packed)
{
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(M, K),
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
}
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto b_ns_ks_lengths = to_tuple(
b_gs_ns_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
const auto b_ns_ks_strides = to_tuple(
b_gs_ns_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
if constexpr(BSpec == TensorSpecialization::Packed)
{
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
make_tuple(N, K),
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
const std::vector<index_t>& e_gs_ms_ns_strides_vec)
{
assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto e_ms_ns_lengths = to_tuple(
e_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto e_ms_ns_strides = to_tuple(
e_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
if constexpr(DESpec == TensorSpecialization::Packed)
{
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(M, N),
make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
else
{
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const auto e_grid_desc_ms_ns =
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
e_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_G_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
const std::vector<index_t>& e_gs_ms_ns_strides_vec)
{
assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto e_gs_ms_ns_lengths =
to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto e_gs_ms_ns_strides =
to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);
if constexpr(DESpec == TensorSpecialization::Packed)
{
auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(G, M, N),
make_tuple(e_gs_ms_ns_strides[Number<NumDimG - 1>{}],
e_gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
e_gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
// return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
return e_grid_desc_g_mraw_nraw;
}
else
{
// naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto e_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
// transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
e_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
return e_grid_desc_g_mraw_nraw;
}
}
static auto MakeDsGridDescriptor_M_N(
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
{
return generate_tuple(
[&](auto i) {
return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
ds_gs_ms_ns_strides_vec[i]);
},
Number<NumDTensor>{});
}
static auto MakeDsGridDescriptor_G_M_N(
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
{
return generate_tuple(
[&](auto i) {
return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
ds_gs_ms_ns_strides_vec[i]);
},
Number<NumDTensor>{});
}
// Gridwise descriptors, mapping to the whole given problem.
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using DsGridDesc_G_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_G_M_N({}, {}))>;
using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {}));
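    // Per-batch (G index) pointer offsets: A and B advance by explicit batch
    // strides, while each D and E derive their step from the G dimension of
    // the corresponding G_M_N descriptor via CalculateOffset(1, 0, 0).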
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t batch_stride_A,
index_t batch_stride_B,
DsGridDesc_G_M_N ds_grid_desc_g_m_n,
EGridDesc_G_M_N e_grid_desc_g_m_n)
: batch_stride_A_(batch_stride_A),
batch_stride_B_(batch_stride_B),
ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n),
e_grid_desc_g_m_n_(e_grid_desc_g_m_n)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(g_idx) * batch_stride_A_;
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(g_idx) * batch_stride_B_;
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = static_cast<long_index_t>(g_idx) *
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
});
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(g_idx) *
e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
}
private:
index_t batch_stride_A_;
index_t batch_stride_B_;
DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
EGridDesc_G_M_N e_grid_desc_g_m_n_;
};
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
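        // unmerge K into [AK0, K1]; K is assumed to be a multiple of K1
        // (the GemmSpec padding applied upstream is expected to ensure this)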
const auto AK0 = K / K1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
// DataType Family
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
K1,
MRepeat,
NRepeat,
// ThreadCluster Family
BlockSize,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
NumPrefetch,
LoopSched,
PipelineVer>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
const std::vector<index_t>& e_gs_ms_ns_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
const std::vector<index_t>& e_gs_ms_ns_strides,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{},
b_grid_desc_n_k_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
ds_grid_desc_g_m_n_{
DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
e_grid_desc_g_m_n_{
DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
e_grid_desc_mblock_mperblock_nblock_nperblock{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_mz_stride_{},
a_kz_stride_{},
b_nz_stride_{},
b_kz_stride_{},
ds_nz_stride_{},
e_nz_stride_{},
a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
compute_ptr_offset_of_batch_{
a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
});
a_grid_desc_m_k_ =
DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_n_k_ =
DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
e_grid_desc_m_n_ =
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
// for sanity check of vector memory access
a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1];
b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1];
b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1];
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1];
}
e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1];
}
// Pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseOp::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// Tensor Descriptors
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
EGridDesc_G_M_N e_grid_desc_g_m_n_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock;
// Block to Tile mapping
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;
// Idle after construction (M01/N01 are consumed when building the block-to-tile map)
index_t M01_;
index_t N01_;
// ElementwiseOp
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
index_t a_mz_stride_;
index_t a_kz_stride_;
index_t b_nz_stride_;
index_t b_kz_stride_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_mz_stride_;
index_t e_nz_stride_;
index_t a_batch_stride_;
index_t b_batch_stride_;
// Batch Offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle<
GridwiseOp,
ADataType,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
DeviceOp::AGridDesc_K0_M_K1,
DeviceOp::BGridDesc_K0_N_K1,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
typename GridwiseOp::DefaultBlock2CTileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
G,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
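        // WMMA path: only gfx1100 is accepted here, and only with float or
        // int32 accumulation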
if(ck::get_device_name() == "gfx1100")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else
{
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
return false;
}
// check vector access
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
"wrong!");
// vector memory access of A: could be on M or AK1 dimension
if constexpr(ABlockTransferSrcVectorDim == 1)
{
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
// vector memory access of B: could be on N or BK1 dimension
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
// vector memory access of Ds: always on NPerBlock dimension
bool valid_d_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetLength(I3) %
CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
0))
{
valid_d_access = false;
}
});
if(valid_d_access == false)
{
return false;
}
// vector memory access of E: always on NPerBlock dimension
if(!((arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I3) %
CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
0) ||
CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
{
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
const std::vector<index_t>& e_gs_ms_ns_lengths,
const std::vector<index_t>& e_gs_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
a_gs_ms_ks_lengths,
b_gs_ns_ks_lengths,
ds_gs_ms_ns_lengths,
e_gs_ms_ns_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_strides,
ds_gs_ms_ns_strides,
e_gs_ms_ns_strides,
1,
1,
a_element_op,
b_element_op,
cde_element_op};
}
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
const std::vector<index_t>& e_gs_ms_ns_lengths,
const std::vector<index_t>& e_gs_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_gs_ms_ks_lengths,
b_gs_ns_ks_lengths,
ds_gs_ms_ns_lengths,
e_gs_ms_ns_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_strides,
ds_gs_ms_ns_strides,
e_gs_ms_ns_strides,
1,
1,
a_element_op,
b_element_op,
cde_element_op);
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
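// Usage sketch (illustrative only; the template arguments, tensor metadata,
// and device buffers below are assumptions, not part of this file):
//   using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle</* ... */>;
//   auto op       = DeviceOp{};
//   auto argument = op.MakeArgument(p_a, p_b, {}, p_e,
//                                   a_lengths, a_strides,
//                                   b_lengths, b_strides,
//                                   {}, {},
//                                   e_lengths, e_strides,
//                                   a_op, b_op, cde_op);
//   if(op.IsSupportedArgument(argument))
//       op.MakeInvoker().Run(argument, StreamConfig{nullptr, true});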
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleD_Wmma_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// K1 = maximum number of contiguous elements accessed per vector memory operation
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
#endif
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
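    // Example (illustrative): M = 128, K = 64, K1 = 8 yields a row-major A
    // descriptor transformed to [K0 = 8, M = 128, K1 = 8], the layout consumed
    // by the blockwise copy.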
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
template <typename ELayout_>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto e_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms,
const std::array<index_t, NumDTensor>& Ns,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(Ms[i], Ns[i], DsStride[i]);
},
Number<NumDTensor>{});
}
// Gridwise descriptors, mapping to the whole given problem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
// DataType Family
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
K1,
MRepeat,
NRepeat,
// ThreadCluster Family
BlockSize,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
NumPrefetch,
LoopSched,
PipelineVer>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
e_grid_desc_mblock_mperblock_nblock_nperblock{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
});
e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_ctile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
}
}
// Pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseOp::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// Tensor Descriptors
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock;
// Block to Tile mapping
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;
// Idle after construction (M01/N01 are consumed when building the block-to-tile map)
index_t M01_;
index_t N01_;
// ElementwiseOp
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
                std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0)
                          << ", " << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
GridwiseOp,
ADataType,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                true>; // last template argument: HasMainKBlockLoop = true
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
GridwiseOp,
ADataType,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else
{
return false;
}
return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
1,
1,
a_element_op,
b_element_op,
cde_element_op};
}
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
1,
1,
a_element_op,
b_element_op,
cde_element_op);
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmMultipleD_Wmma_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
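// Usage sketch (illustrative only; the template arguments, problem sizes, and
// device buffers below are assumptions, not part of this file):
//   using DeviceOp = DeviceGemmMultipleD_Wmma_CShuffle</* ... */>;
//   auto op       = DeviceOp{};
//   auto argument = op.MakeArgument(p_a, p_b, {}, p_e,
//                                   M, N, K, StrideA, StrideB, {}, StrideE,
//                                   a_op, b_op, cde_op);
//   if(op.IsSupportedArgument(argument))
//       op.MakeInvoker().Run(argument, StreamConfig{nullptr, true});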
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
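// Per-group pointer offsets computed from explicit batch (G) strides for A,
// B, each D, and E; the grouped kernel uses these to index into group g_idx.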
template <index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
};
} // namespace
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
// Assume:
// AK1 == BK1
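// Illustrative implicit-GEMM view (a sketch; the exact descriptors are built
// by TransformConvFwdToGemm below). For the 2D case, per group:
//   GemmM = N * Ho * Wo
//   GemmN = K
//   GemmK = Y * X * C
// so that out[GemmM, GemmN] = in_im2col[GemmM, GemmK] * wei[GemmN, GemmK]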
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
: public DeviceGroupedConvFwdMultipleD<NDimSpatial,
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i]);
},
Number<NumDTensor>{});
}
// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK1 = K1;
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK1 = K1;
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}));
using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}));
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
// DataType Family
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
K1,
MRepeat,
NRepeat,
// ThreadCluster Family
BlockSize,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
NumGemmKPrefetchStage,
LoopSched,
PipelineVer>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
index_t M01,
index_t N01,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
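            // Index 0 of every length/stride array is the group dimension G,
            // so stride[0] is the pointer offset between consecutive groups.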
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
});
// D desc
ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides);
// populate desc for Ds/E
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseOp::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
        // tensor descriptors for problem definition
index_t num_group_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
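            // Total GEMM K is recovered as AK0 * AK1; below it selects, at
            // compile time, a kernel with or without the main K-block loop.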
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
GridwiseOp,
ADataType,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<NumDTensor>,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx1100")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
                const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; // spatial lengths start after G, K, C
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
                const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; // spatial lengths start after G, K, C
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of Ds
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// FIXME: layout
if constexpr(is_same_v<DLayout, ctc::G_NW_K> || is_same_v<DLayout, ctc::G_NHW_K> ||
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
is_same_v<DLayout, ctc::G_K>)
{
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
valid = false;
}
}
else
{
valid = false;
}
});
if(!valid)
{
return false;
}
// check vector access of E
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
// check Gridwise GEMM
return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
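    // Typical host-side flow (a minimal sketch; the pointer and array
    // arguments are placeholders, not defined in this header):
    //   DeviceOp op;
    //   auto arg = op.MakeArgument(p_a, p_b, p_ds, p_e,
    //                              a_lengths, a_strides, b_lengths, b_strides,
    //                              ds_lengths, ds_strides, e_lengths, e_strides,
    //                              conv_strides, conv_dilations,
    //                              left_pads, right_pads,
    //                              AElementwiseOperation{},
    //                              BElementwiseOperation{},
    //                              CDEElementwiseOperation{});
    //   if(DeviceOp::IsSupportedArgument(arg))
    //   {
    //       auto invoker = op.MakeInvoker();
    //       float time   = invoker.Run(arg, StreamConfig{});
    //   }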
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
        << K0PerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,46 +10,11 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
acc_elementwise_op);
};
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
@@ -58,9 +23,9 @@ namespace device {
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename AccDataType,
typename ComputeDataType,
typename YDataType,
typename AccElementwiseOperation,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
@@ -74,16 +39,18 @@ template <typename XDataType,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorSize>
index_t YDstVectorSize,
bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
ComputeDataType,
YDataType,
AccElementwiseOperation,
YElementwiseOperation,
Rank,
NumReduceDim>
{
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
@@ -167,51 +134,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnce =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
true>;
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> lengths,
@@ -220,7 +142,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op,
YElementwiseOperation y_elementwise_op,
double epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
@@ -230,9 +152,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
acc_elementwise_op_(acc_elementwise_op)
y_elementwise_op_(y_elementwise_op)
{
epsilon_ = static_cast<AccDataType>(epsilon);
epsilon_ = static_cast<ComputeDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
@@ -265,7 +187,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
}
AccDataType epsilon_;
ComputeDataType epsilon_;
const XDataType* p_x_;
const GammaDataType* p_gamma_;
@@ -278,7 +200,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_;
AccElementwiseOperation acc_elementwise_op_;
YElementwiseOperation y_elementwise_op_;
int blkGroupSize_;
int numBlockTileIteration_;
@@ -295,23 +217,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel_main = arg.isSweeponce_
? kernel_normalization<GridwiseNormalizationSweepOnce,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K>
: kernel_normalization<GridwiseReduceLayernormGeneric,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K>;
auto kernel_main = NormalizationKernelSelector<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
UseWelford>(arg.isSweeponce_);
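            // Judging from the replaced ternary above, the selector returns
            // the sweep-once kernel when the whole K extent fits in a single
            // block tile pass (arg.isSweeponce_) and the generic
            // multi-iteration kernel otherwise; UseWelford picks the
            // Welford-variance implementation.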
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
@@ -329,7 +255,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.acc_elementwise_op_);
arg.y_elementwise_op_);
return (avg_time);
};
@@ -429,7 +355,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
void* p_y,
void* p_saveMean,
void* p_saveInvVar,
AccElementwiseOperation acc_elementwise_op) override
YElementwiseOperation y_elementwise_op) override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
@@ -443,7 +369,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
betaStrides,
yStrides,
reduceDims,
acc_elementwise_op,
y_elementwise_op,
epsilon,
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
@@ -462,8 +388,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
// clang-format off
str << "DeviceNormalizationImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
// clang-format on
@@ -150,6 +150,13 @@ struct Bilinear
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
template <>
__host__ __device__ constexpr void
operator()<double, double, double>(double& y, const double& x0, const double& x1) const
{
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
@@ -95,6 +95,12 @@ struct Scale
y = scale_ * x;
};
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = scale_ * x;
};
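    // Note: scale_ is stored as float (below), so the double specialization
    // promotes it to double at the multiply.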
float scale_;
};