Commit d0b49a14 authored by Qianfeng Zhang

Merge branch 'develop' into bnorm_bwd_pr

parents 29026b0e 87fd1152
...@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
......
...@@ -12,7 +12,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
......
...@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
......
...@@ -9,7 +9,7 @@
 #include "ck/ck.hpp"
 #include "ck/utility/reduction_enums.hpp"
-#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
 #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
 #include "ck/library/utility/fill.hpp"
...@@ -47,34 +47,34 @@ struct YElementOp
 };
 using DeviceInstance =
-    ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
+    ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                       GammaDataType,
                                                       BetaDataType,
                                                       AccDataType,
                                                       YDataType,
                                                       YElementOp,
                                                       Rank,
                                                       NumReduceDim,
-                                                      256,  // BlockSize
-                                                      8,    // ClusterM
-                                                      32,   // ClusterK
-                                                      1,    // SliceM
-                                                      8,    // SliceK
-                                                      1,    // SrcVecDim (0=M, 1=K)
-                                                      8,    // SrcScalarPerVector
-                                                      1,    // GammaVecDim (0=M, 1=K)
-                                                      8,    // GammaScalarPerVector
-                                                      1,    // BetaVecDim (0=M, 1=K)
-                                                      8,    // BetaScalarPerVector
-                                                      8>;   // OutScalarPerVector
+                                                      1024, // BlockSize
+                                                      1,    // ClusterM
+                                                      1024, // ClusterK
+                                                      1,    // SliceM
+                                                      32,   // SliceK
+                                                      1,    // SrcVecDim (0=M, 1=K)
+                                                      2,    // SrcScalarPerVector
+                                                      1,    // GammaVecDim (0=M, 1=K)
+                                                      2,    // GammaScalarPerVector
+                                                      1,    // BetaVecDim (0=M, 1=K)
+                                                      2,    // BetaScalarPerVector
+                                                      2>;   // OutScalarPerVector
 int main(int argc, char* argv[])
 {
-    ck::index_t N = 128;
-    ck::index_t H = 16;
-    ck::index_t W = 16;
+    ck::index_t N = 2;
+    ck::index_t H = 32;
+    ck::index_t W = 32;
     ck::index_t G = 32;
-    ck::index_t C = 40;
+    ck::index_t C = 30;
     if(argc == 1)
     {
......
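For orientation: with the (N, H, W, G, C) shape used above, this example's normalization reduces over the H, W, and per-group C axes before applying gamma, beta, and YElementOp. A minimal host-side sketch of that reference computation, assuming a packed NHWGC layout and (G, C)-shaped gamma/beta; the helper below is illustrative and not part of the commit:

#include <cmath>
#include <vector>

// Illustrative reference only: per-(n, g) mean/variance over (H, W, C), then
// scale and shift; application of YElementOp is omitted for brevity.
void reference_normalization(const std::vector<float>& x,
                             const std::vector<float>& gamma, // shape (G, C)
                             const std::vector<float>& beta,  // shape (G, C)
                             std::vector<float>& y,
                             int N, int H, int W, int G, int C,
                             float epsilon = 1e-5f)
{
    auto at = [&](int n, int h, int w, int g, int c) {
        return (((n * H + h) * W + w) * G + g) * C + c; // packed NHWGC offset
    };
    for(int n = 0; n < N; ++n)
        for(int g = 0; g < G; ++g)
        {
            const long count = long(H) * W * C;
            double sum = 0.0, sum_sq = 0.0;
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    for(int c = 0; c < C; ++c)
                    {
                        const double v = x[at(n, h, w, g, c)];
                        sum += v;
                        sum_sq += v * v;
                    }
            const double mean    = sum / count;
            const double var     = sum_sq / count - mean * mean;
            const double inv_std = 1.0 / std::sqrt(var + epsilon);
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    for(int c = 0; c < C; ++c)
                    {
                        const int i = at(n, h, w, g, c);
                        y[i] = gamma[g * C + c] * float((x[i] - mean) * inv_std) +
                               beta[g * C + c];
                    }
        }
}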
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimG == 2, NumDimM == 2, NumDimN == 2, NumDimK == 1
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G2_M2_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int split_k = 1;
ck::index_t G0 = 1;
ck::index_t G1 = 2;
ck::index_t M0 = 4;
ck::index_t M1 = 256;
ck::index_t N0 = 16;
ck::index_t N1 = 128;
ck::index_t K0 = 64 * 2;
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
split_k = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op,
split_k);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(),
e_gs_ms_ns_lengths.begin() + NumDimG,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
e_gs_ms_ns_host_result.ForEach([&](auto&, auto idx) {
cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
});
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
}
return 0;
}
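As a sanity check on the perf accounting in main(): the default shapes give G = G0 * G1 = 2, M = M0 * M1 = 1024, N = N0 * N1 = 2048, and K = K0 = 128, so

flop      = 2 * 2 * 1024 * 2048 * 128 = 1,073,741,824 (about 1.07 GFlop)
num_btype = 2 bytes * (G*M*K + G*K*N + 2*G*M*N) = 2 * 9,175,040 = 18,350,080 bytes

With a hypothetical ave_time of 0.1 ms, the program would therefore report roughly 10.7 TFlops and 184 GB/s.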
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimG == 2, NumDimM == 2, NumDimN == 2, NumDimK == 1
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G2_M2_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int split_k = 1;
ck::index_t G0 = 1;
ck::index_t G1 = 2;
ck::index_t M0 = 4;
ck::index_t M1 = 256;
ck::index_t N0 = 16;
ck::index_t N1 = 128;
ck::index_t K0 = 64 * 2;
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
split_k = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op,
split_k);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(),
e_gs_ms_ns_lengths.begin() + NumDimG,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
e_gs_ms_ns_host_result.ForEach([&](auto&, auto idx) {
cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
});
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
}
return 0;
}
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
ck::Tuple<BDataType>,
PassThrough,
4,
8,
ck::Sequence<8>,
ck::Sequence<1>>;
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
1};
std::array<ck::index_t, 4> b_strides = {static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]),
1,
static_cast<int>(nhwc[2] * nhwc[3]),
static_cast<int>(nhwc[3])};
std::copy(nchw.begin(), nchw.end(), ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
std::cout << "A (nchw): " << a.mDesc << std::endl;
std::cout << "B (nhwc): " << b.mDesc << std::endl;
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{});
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}
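Note how the layout change is expressed purely through strides: ab_lengths stays in NCHW order, while b_strides gives the offset each NCHW coordinate receives in the NHWC output buffer. A small model of the resulting addressing, using the concrete numbers from this example (the helper is illustrative only):

#include <cstddef>

// For nchw = {16, 128, 32, 64} and nhwc = {16, 32, 64, 128}, b_strides works out
// to {262144, 1, 8192, 128}. Element (n, c, h, w) of the logical NCHW tensor is
// written at this offset, i.e. into a packed row-major NHWC buffer.
std::size_t nhwc_offset(std::size_t n, std::size_t c, std::size_t h, std::size_t w)
{
    return n * 262144 + c * 1 + h * 8192 + w * 128;
}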
...@@ -159,6 +159,11 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
+// workaround: disable broken fused attention kernel instance that does not pass validation
+// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
+// enabled
+#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
+
 namespace ck {
 enum struct InMemoryDataOperationEnum
......
...@@ -14,7 +14,8 @@ namespace ck {
 template <typename TensorLengths,
           typename DimAccessOrder,
-          typename ScalarsPerAccess> // # of scalars per access in each dimension
+          typename ScalarsPerAccess, // # of scalars per access in each dimension
+          bool SnakeCurved = true>
 struct SpaceFillingCurve
 {
     static constexpr index_t nDim = TensorLengths::Size();
...@@ -136,9 +137,10 @@ struct SpaceFillingCurve
         Index ordered_idx;
         static_for<0, nDim, 1>{}([&](auto idim) {
-            ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim]
-                                                    : ordered_access_lengths[idim] - 1 -
-                                                          ordered_access_idx[idim];
+            ordered_idx(idim) =
+                !SnakeCurved || forward_sweep[idim]
+                    ? ordered_access_idx[idim]
+                    : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
         });
         return container_reorder_given_old2new(ordered_idx, dim_access_order) *
......
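The new SnakeCurved flag controls whether backward sweeps happen at all: with SnakeCurved = false, the !SnakeCurved term short-circuits the ternary and every dimension is swept forward, while the default true reverses alternate sweeps so consecutive accesses stay adjacent. A small host-side model of the 2D case (illustrative; the real class operates on compile-time tensor descriptors):

#include <cstdio>

// Model of the index logic above for a [rows x cols] traversal: with snake = true,
// odd rows are swept backwards; with snake = false, every row is swept forward.
void print_order(int rows, int cols, bool snake)
{
    for(int r = 0; r < rows; ++r)
        for(int c = 0; c < cols; ++c)
        {
            const bool forward = !snake || (r % 2 == 0);
            std::printf("(%d,%d) ", r, forward ? c : cols - 1 - c);
        }
    std::printf("\n");
}

int main()
{
    print_order(2, 4, true);  // (0,0) (0,1) (0,2) (0,3) (1,3) (1,2) (1,1) (1,0)
    print_order(2, 4, false); // (0,0) (0,1) (0,2) (0,3) (1,0) (1,1) (1,2) (1,3)
}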
...@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         return make_tuple(c_thread_m, c_thread_n);
     }
+
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto
+    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
+        return make_tuple(Number<m0>{},
+                          Number<n0>{},
+                          waveId_m,
+                          waveId_n,
+                          blk_idx[I0],
+                          blk_idx[I1],
+                          blk_idx[I2],
+                          blk_idx[I3]);
+    }
+
     __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
...@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
         return make_tuple(c_thread_m, c_thread_n);
     }
+
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto
+    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
+        return make_tuple(
+            m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
+    }
+
     using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
     __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
......
...@@ -24,7 +24,8 @@ template <typename ALayout,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          bool MaskOutUpperTriangle> // TODO: enum for mask type
 struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
......
...@@ -7,49 +7,60 @@
 #include <vector>
 #include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
 struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b0,
-                        const void* p_b1,
-                        void* p_c,
-                        ck::index_t M,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t O,
-                        ck::index_t Batch,
-                        std::vector<index_t> c_gs_ms_os_lengths,
-                        std::vector<index_t> c_gs_ms_os_strides,
-                        ck::index_t StrideA,
-                        ck::index_t StrideB0,
-                        ck::index_t StrideB1,
-                        ck::index_t BatchStrideA,
-                        ck::index_t BatchStrideB0,
-                        ck::index_t BatchStrideB1,
-                        AElementwiseOperation a_element_op,
-                        B0ElementwiseOperation b0_element_op,
-                        Acc0ElementwiseOperation acc0_element_op,
-                        B1ElementwiseOperation b1_element_op,
-                        CElementwiseOperation c_element_op) = 0;
+    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
+    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
+        const void* p_a,
+        const void* p_b0,
+        const void* p_b1,
+        void* p_c,
+        const std::array<void*, NumAcc0Bias> p_acc0_biases,
+        const std::array<void*, NumAcc1Bias> p_acc1_biases,
+        const std::vector<index_t>& a_gs_ms_ks_lengths,
+        const std::vector<index_t>& a_gs_ms_ks_strides,
+        const std::vector<index_t>& b_gs_ns_ks_lengths,
+        const std::vector<index_t>& b_gs_ns_ks_strides,
+        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
+        const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
+        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
+        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
+        const std::array<std::vector<index_t>, NumAcc1Bias>
+            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
+        const std::array<std::vector<index_t>, NumAcc1Bias>
+            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+        AElementwiseOperation a_element_op,
+        B0ElementwiseOperation b0_element_op,
+        Acc0ElementwiseOperation acc0_element_op,
+        B1ElementwiseOperation b1_element_op,
+        CElementwiseOperation c_element_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
......
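To make the new interface concrete: the scalar M/N/K/O, Batch, and stride arguments are replaced by per-tensor length/stride vectors in gs_ms_ks style. A hedged sketch of the host-side shape setup for an attention-like problem with NumDimG = 2 (batch and head) and one dimension each for M/N/K/O; all names and numbers are illustrative:

#include <vector>
#include "ck/ck.hpp"

// Illustrative only: shapes for Q*K^T -> softmax -> *V, all tensors packed.
int main()
{
    ck::index_t G0 = 4, G1 = 8, M = 256, N = 256, K = 64, O = 64;

    // Q viewed as [G0, G1, M, K], row-major
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M * K, M * K, K, 1};

    // K viewed as [G0, G1, N, K], row-major
    std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N * K, N * K, K, 1};

    // V stored as [G0, G1, N, O] but described as b1_gs_os_ns, hence swapped strides
    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides{G1 * N * O, N * O, 1, O};

    // Output C as [G0, G1, M, O], row-major
    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{G1 * M * O, M * O, O, 1};
}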
...@@ -13,31 +13,36 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-template <index_t Rank, index_t NumBatchNormReduceDim>
+template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
 struct DeviceBatchNormFwd : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
         const std::array<index_t, Rank> xyLengths,
         const std::array<index_t, Rank> xStrides,
         const std::array<index_t, Rank> yStrides,
+        const std::array<int, NumBatchNormReduceDim> reduceDims,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
         const void* p_x,
         const void* bnScale,
         const void* bnBias,
+        double epsilon,
+        const YElementwiseOp y_elementwise_op,
         void* p_y,
+        void* resultSaveMean,
+        void* resultSaveInvVariance,
         double exponentialAverageFactor,
         void* resultRunningMean,
-        void* resultRunningVariance,
-        double epsilon,
-        void* resultSaveMean,
-        void* resultSaveInvVariance) = 0;
+        void* resultRunningVariance) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <index_t Rank, index_t NumBatchNormReduceDim>
-using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;
+template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
+using DeviceBatchNormFwdPtr =
+    std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
 } // namespace device
 } // namespace tensor_operation
......
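The reordered signature groups the inputs with epsilon and the new YElementwiseOp, then the saved statistics, then the running statistics, and the single bnScaleBiasMeanVarStrides parameter is split into separate scale, bias, and mean/variance strides. A hedged sketch of the shapes for an NHWC case (values illustrative, not from the commit):

#include <array>
#include "ck/ck.hpp"

// Illustrative only: Rank = 4, NumBatchNormReduceDim = 3, i.e. reduce over
// N, H, W while C (the innermost NHWC dimension) is the invariant length.
int main()
{
    constexpr ck::index_t Rank      = 4;
    constexpr ck::index_t NumReduce = 3;

    std::array<ck::index_t, Rank> xyLengths{256, 14, 14, 1024}; // N, H, W, C
    std::array<ck::index_t, Rank> xStrides{14 * 14 * 1024, 14 * 1024, 1024, 1};
    std::array<ck::index_t, Rank> yStrides = xStrides;

    std::array<int, NumReduce> reduceDims{0, 1, 2}; // new explicit parameter

    std::array<ck::index_t, Rank - NumReduce> bnScaleBiasMeanVarLengths{1024};
    std::array<ck::index_t, Rank - NumReduce> bnScaleStrides{1};   // now separate from
    std::array<ck::index_t, Rank - NumReduce> bnBiasStrides{1};    // bias and mean/var
    std::array<ck::index_t, Rank - NumReduce> bnMeanVarStrides{1}; // strides
}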
...@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
         const std::array<index_t, Rank> xStrides,
         const std::array<index_t, Rank> yStrides,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
         const void* p_x,
         const void* bnScale,
         const void* bnBias,
......
...@@ -7,46 +7,50 @@
 #include <vector>
 #include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
 struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
 {
     struct ProblemDesc
     {
-        // Overall problem shape
-        index_t M;
-        index_t N;
-        index_t K;
-        index_t O;
-        index_t Batch;
-
-        // Stride for A/B0/B1; layout determined by template args
-        index_t StrideA;
-        index_t StrideB0;
-        index_t StrideB1;
-        index_t BatchStrideA;
-        index_t BatchStrideB0;
-        index_t BatchStrideB1;
-
-        // Lengths and strides for output C
+        std::vector<index_t> a_gs_ms_ks_lengths;
+        std::vector<index_t> a_gs_ms_ks_strides;
+
+        std::vector<index_t> b0_gs_ns_ks_lengths;
+        std::vector<index_t> b0_gs_ns_ks_strides;
+
+        std::vector<index_t> b1_gs_os_ns_lengths;
+        std::vector<index_t> b1_gs_os_ns_strides;
+
         std::vector<index_t> c_gs_ms_os_lengths;
         std::vector<index_t> c_gs_ms_os_strides;
+
+        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
+        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
+
+        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
+        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
     };
 
     virtual std::unique_ptr<BaseArgument>
...@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
                         std::vector<const void*> p_b0_vec,
                         std::vector<const void*> p_b1_vec,
                         std::vector<void*> p_c_vec,
+                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
                         std::vector<ProblemDesc> problem_desc_vec,
                         AElementwiseOperation a_element_op,
                         B0ElementwiseOperation b0_element_op,
......
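The practical effect of the ProblemDesc change is that every group now carries a complete shape description, so groups may have entirely different sequence lengths. A hedged sketch of filling such descriptors; the mirror struct below omits the bias vectors and is illustrative, not the commit's type:

#include <vector>
#include "ck/ck.hpp"

// Illustrative mirror of the new ProblemDesc aggregate (bias vectors omitted).
struct ProblemDescSketch
{
    std::vector<ck::index_t> a_gs_ms_ks_lengths, a_gs_ms_ks_strides;
    std::vector<ck::index_t> b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides;
    std::vector<ck::index_t> b1_gs_os_ns_lengths, b1_gs_os_ns_strides;
    std::vector<ck::index_t> c_gs_ms_os_lengths, c_gs_ms_os_strides;
};

// One variable-length attention group per sequence; 'heads' and 'dim' are assumed.
std::vector<ProblemDescSketch> make_descs(const std::vector<ck::index_t>& seq_lens,
                                          ck::index_t heads, ck::index_t dim)
{
    std::vector<ProblemDescSketch> descs;
    for(ck::index_t s : seq_lens)
    {
        ProblemDescSketch d;
        d.a_gs_ms_ks_lengths  = {1, heads, s, dim};
        d.a_gs_ms_ks_strides  = {heads * s * dim, s * dim, dim, 1};
        d.b0_gs_ns_ks_lengths = {1, heads, s, dim};
        d.b0_gs_ns_ks_strides = {heads * s * dim, s * dim, dim, 1};
        d.b1_gs_os_ns_lengths = {1, heads, dim, s}; // V stored [1, heads, s, dim]
        d.b1_gs_os_ns_strides = {heads * s * dim, s * dim, 1, dim};
        d.c_gs_ms_os_lengths  = {1, heads, s, dim};
        d.c_gs_ms_os_strides  = {heads * s * dim, s * dim, dim, 1};
        descs.push_back(d);
    }
    return descs;
}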
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -54,9 +55,8 @@ __global__ void ...@@ -54,9 +55,8 @@ __global__ void
index_t right = group_count; index_t right = group_count;
index_t group_id = index_t((left + right) / 2); index_t group_id = index_t((left + right) / 2);
while((!(block_id >= arg_ptr[group_id].block_start_ && while(
block_id < arg_ptr[group_id].block_end_)) && (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
left <= right)
{ {
if(block_id < arg_ptr[group_id].block_start_) if(block_id < arg_ptr[group_id].block_start_)
{ {
...@@ -114,14 +114,17 @@ __global__ void ...@@ -114,14 +114,17 @@ __global__ void
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
template <typename ALayout, template <index_t NumDimG,
typename BLayout, // B0Layout index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N> index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -130,6 +133,10 @@ template <typename ALayout, ...@@ -130,6 +133,10 @@ template <typename ALayout,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -170,297 +177,152 @@ template <typename ALayout, ...@@ -170,297 +177,152 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceGroupedGemmSoftmaxGemmPermute<ALayout, : public DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
BLayout, NumDimM,
B1Layout, NumDimN,
CPermuteNumDims_G_M_Gemm1N, NumDimK,
NumDimO,
ADataType, ADataType,
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskingSpec>
{ {
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
using ProblemDesc = "Number of dimension must be greater than 0");
typename DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
BLayout, static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
B1Layout, static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
CPermuteNumDims_G_M_Gemm1N,
ADataType, // TODO ANT: implement bias combination
BDataType, static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
B1DataType,
CDataType, #if 0
AElementwiseOperation, // TODO ANT: use alias
BElementwiseOperation, static constexpr index_t NumDimGemm0M = NumDimM;
AccElementwiseOperation, static constexpr index_t NumDimGemm0N = NumDimN;
B1ElementwiseOperation, static constexpr index_t NumDimGemm0K = NumDimK;
CElementwiseOperation>::ProblemDesc; static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>::ProblemDesc;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
-    static constexpr auto matrix_padder =
-        GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
-            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
+    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
+        Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
+        Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
+        GemmSpec,
+        ASpec,
+        BSpec,
+        B1Spec,
+        CSpec>;

-    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
+    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+                                              const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
-        const auto a_grid_desc_mraw_kraw = [&]() {
-            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(I1, StrideA));
-            }
-        }();
-
-        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
-
-        const auto M = a_grid_desc_m_k.GetLength(I0);
-        const auto K = a_grid_desc_m_k.GetLength(I1);
-
-        const auto AK0 = K / AK1;
-
-        return transform_tensor_descriptor(a_grid_desc_m_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                      make_pass_through_transform(M)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return Transform::MakeAGridDescriptor_AK0_M_AK1(
+            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
+            Number<AK1>{});
     }

-    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
+                                              const std::vector<index_t>& b_gs_ns_ks_strides_vec)
     {
-        const auto b_grid_desc_nraw_kraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
-            }
-        }();
-
-        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
-
-        const auto N = b_grid_desc_n_k.GetLength(I0);
-        const auto K = b_grid_desc_n_k.GetLength(I1);
-
-        const auto BK0 = K / BK1;
-
-        return transform_tensor_descriptor(b_grid_desc_n_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                      make_pass_through_transform(N)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
+            Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
+            Number<BK1>{});
     }

-    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
-    static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto
+    MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                   const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
     {
-        const auto b1_grid_desc_nraw_kraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
-            }
-        }();
-
-        const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
-
-        const auto N = b1_grid_desc_n_k.GetLength(I0);
-        const auto K = b1_grid_desc_n_k.GetLength(I1);
-
-        const auto B1K0 = K / B1K1;
-
-        return transform_tensor_descriptor(
-            b1_grid_desc_n_k,
-            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
-                       make_pass_through_transform(N)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
+            Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                                b1_gs_gemm1ns_gemm1ks_strides_vec),
+            Number<B1K1>{});
     }

-    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
-                                        const std::vector<index_t>& c_gs_ms_ns_strides_vec)
-    {
-        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
-        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_ms_ns_lengths = to_tuple(
-            c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_ms_ns_strides = to_tuple(
-            c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds =
-            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
-
-        // lengths for N0, N1, ...
-        const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_ms_ns =
-            make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
-
-        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
-        const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
-            c_grid_desc_ms_ns,
-            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
-            make_tuple(mDimIds, nDimIds),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
-    }
-
-    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
-                                          const std::vector<index_t>& c_gs_ms_ns_strides_vec)
-    {
-        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
-        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_gs_ms_ns_lengths =
-            to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_gs_ms_ns_strides =
-            to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for G0, G1, ...
-        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds =
-            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
-                                                                  NumDimG + NumDimM + NumDimN,
-                                                                  1>::type{};
-
-        // lengths for G0, G1, ...
-        const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
-
-        // lengths for N0, N1, ...
-        const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_gs_ms_ns =
-            make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
-
-        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
-        // N2 * ...]
-        const auto c_grid_desc_g_mraw_nraw =
-            transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
-                                        make_tuple(make_merge_transform(gLengths),
-                                                   make_merge_transform(mLengths),
-                                                   make_merge_transform(nLengths)),
-                                        make_tuple(gDimIds, mDimIds, nDimIds),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-        // this desc is only for calculating batch offset so no padding needed
-        return c_grid_desc_g_mraw_nraw;
-    }
-    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
-    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
-    using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));
+    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
+    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
+    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
+    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
+    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
+    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
+    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-    // to track the points which need to be set to -inf on C0
-    // Note: no need to reset M padding value, because they will not be stored out.
-    struct C0MatrixMask
-    {
-        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
-
-        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
-
-        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
-        {
-            return n >= NRaw_;
-        }
-
-        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
-        {
-            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
-        }
-
-        private:
-        // index_t MRaw_;
-        index_t NRaw_;
-    };
+    constexpr static auto make_MaskOutPredicate()
+    {
+        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
+        {
+            return MaskDisabledPredicate{};
+        }
+        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
+        {
+            return MaskOutUpperTrianglePredicate{};
+        }
+    }
+
+    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
     struct ComputeBasePtrOfStridedBatch
     {
-        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
-                                     index_t BatchStrideB,
-                                     index_t BatchStrideB1,
-                                     CGridDesc_G_M_N c_grid_desc_g_m_n)
-            : BatchStrideA_(BatchStrideA),
-              BatchStrideB_(BatchStrideB),
-              BatchStrideB1_(BatchStrideB1),
+        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
+                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
+                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
+                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
+            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
+              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
+              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
               c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
         {
         }

         __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideA_);
+            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB_);
+            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
+            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }

        private:
-        index_t BatchStrideA_;
-        index_t BatchStrideB_;
-        index_t BatchStrideB1_;
+        AGridDesc_G_M_K a_grid_desc_g_m_k_;
+        BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
     };
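
Note: the change above replaces per-tensor batch strides with full G_M_K/G_N_K/G_M_N descriptors, so the per-batch base pointer becomes the descriptor offset at (g, 0, 0). A minimal standalone sketch of that idea in plain C++ (illustrative types, not CK's descriptor machinery):

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Illustrative stand-in for a G/M/N tensor descriptor: lengths plus strides.
struct NaiveDesc3D
{
    std::array<std::int64_t, 3> lengths;
    std::array<std::int64_t, 3> strides;

    // Offset of element (g, m, n) in linear memory, analogous to CalculateOffset().
    std::int64_t offset(std::int64_t g, std::int64_t m, std::int64_t n) const
    {
        return g * strides[0] + m * strides[1] + n * strides[2];
    }
};

int main()
{
    // Packed C[G=4, M=8, N=16]: the batch stride is M * N = 128.
    NaiveDesc3D c_desc{{4, 8, 16}, {8 * 16, 16, 1}};

    // The base pointer of batch g is the offset at (g, 0, 0) -- identical to
    // g * BatchStride for packed layouts, but also correct for strided ones.
    for(std::int64_t g = 0; g < c_desc.lengths[0]; ++g)
        std::cout << "batch " << g << " base offset: " << c_desc.offset(g, 0, 0) << '\n';
    return 0;
}
```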
@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        matrix_padder.PadN,
-        MaskOutUpperTriangle>;
+        Transform::matrix_padder.PadN,
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;

     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     struct GroupDeviceArg
     {
-        // problem definiton
-        index_t M;
-        index_t N;
-        index_t K;
-        index_t O;
-
-        // Strides for the last dimensions of C for sanity check of vector load/store
-        index_t c_extent_lowest_;
-        index_t c_stride_lowest_;
+        // lengths for the last dimensions of overall problem for sanity check of vector load/store
+        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
+
+        // strides for the last dimensions of each tensor for sanity check of vector load/store
+        std::vector<index_t> a_mz_kz_strides_;
+        std::vector<index_t> b_nz_kz_strides_;
+        std::vector<index_t> b1_nz_kz_strides_;
+        std::vector<index_t> c_mz_gemm1nz_strides_;

+        // for gridwise gemm check
         CGridDesc_M_N c_grid_desc_m_n_;
     };
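
GroupDeviceArg now carries the raw lowest-dimension lengths and strides of every tensor so that IsSupportedArgument can reject configurations where a vector load or store would cross a row boundary. A standalone sketch of that per-tensor check (illustrative names, not the CK implementation):

```cpp
#include <iostream>

// A vector access of width `scalar_per_vector` along the fastest dimension is
// only safe when that dimension is contiguous (stride 1) and the raw extent
// divides evenly, since threadwise copy cannot split a vector at the boundary.
bool vector_access_ok(int raw_extent, int stride_lowest, int scalar_per_vector)
{
    return stride_lowest == 1 && raw_extent % scalar_per_vector == 0;
}

int main()
{
    std::cout << vector_access_ok(128, 1, 8) << '\n'; // 1: 128 % 8 == 0, contiguous
    std::cout << vector_access_ok(100, 1, 8) << '\n'; // 0: 100 % 8 != 0
    std::cout << vector_access_ok(128, 4, 8) << '\n'; // 0: strided fastest dimension
    return 0;
}
```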
@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  std::vector<const void*> p_b_vec,
                  std::vector<const void*> p_b1_vec,
                  std::vector<void*> p_c_vec,
+                 std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                 std::vector<std::vector<const void*>> p_acc1_biases_vec,
                  std::vector<ProblemDesc> problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
               b1_element_op_{b1_element_op},
               c_element_op_{c_element_op}
         {
+            // TODO ANT: implement bias addition
             group_count_ = problem_desc_vec.size();

             if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
             }

+            if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
+            {
+                throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
+            }
+
             grid_size_ = 0;

             for(std::size_t i = 0; i < group_count_; i++)
@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
                 const auto p_c_grid  = static_cast<CDataType*>(p_c_vec[i]);

-                const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(
-                    problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
-                const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
-                    problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0);
-                const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
-                    problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1);
-                const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N(
-                    problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
+                const auto& problem_desc = problem_desc_vec[i];
+
+                const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
+                    problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
+                const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
+                    problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
+                const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
+                    problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
+                const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
+                    problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
+
+                const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
+                    problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
+                const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
+                    problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
+                const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
+                    problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
+                const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
+                    problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);

                 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const index_t BlockStart     = grid_size_;
                 const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
-                const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) *
-                                              problem_desc_vec[i].Batch;
+                const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
+                const index_t grid_size_grp =
+                    block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
                 const index_t BlockEnd = grid_size_ + grid_size_grp;

                 // batch stride
-                // TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure
-                const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N(
-                    problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
-                const auto compute_base_ptr_of_batch =
-                    ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
-                                                 problem_desc_vec[i].BatchStrideB0,
-                                                 problem_desc_vec[i].BatchStrideB1,
-                                                 c_grid_desc_g_m_n);
+                const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
+                    a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);

                 // C0 mask
-                const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
+                const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));

                 grid_size_ += grid_size_grp;

+                // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
+                // so on
+                if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
+                     problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
+                     problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
+                     problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
+                {
+                    throw std::runtime_error(
+                        "wrong! number of biases in function argument does not "
+                        "match that in template argument");
+                }
+
                 group_kernel_args_.push_back({p_a_grid,
                                               p_b_grid,
                                               p_b1_grid,
@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                               BlockStart,
                                               BlockEnd});

-                group_device_args_.push_back({problem_desc_vec[i].M,
-                                              problem_desc_vec[i].N,
-                                              problem_desc_vec[i].K,
-                                              problem_desc_vec[i].O,
-                                              problem_desc_vec[i].c_gs_ms_os_lengths.back(),
-                                              problem_desc_vec[i].c_gs_ms_os_strides.back(),
-                                              c_grid_desc_m_n});
+                group_device_args_.push_back(
+                    {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
+                      problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
+                      problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
+                      problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
+                     {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
+                      problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
+                     {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
+                      problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
+                     {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
+                      problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
+                     {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
+                      problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
+                     c_grid_desc_m_n});
             }
         }
@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 return false;
             }

+            // TODO ANT: Check if tensor specialization & strides mismatch
+
             bool all_has_main_k_block_loop  = true;
             bool some_has_main_k_block_loop = false;
@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             // Note: we need raw lengths since threadwise copy can not handle vector load when
            // part of vector is out of bounds
-            const auto MRaw      = device_arg.M;
-            const auto NRaw      = device_arg.N;
-            const auto KRaw      = device_arg.K;
-            const auto Gemm1NRaw = device_arg.O;
+            const auto MzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
+            const auto NzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
+            const auto KzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
+            const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

             // Check scalar per vector requirement
-            const auto a_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-            const auto b_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-            const auto b1_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-            const auto c_extent_lowest = device_arg.c_extent_lowest_;
+            const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
+            const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
+            const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
+            const auto c_extent_lowest  = Gemm1NzRaw;

             if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
                  b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 return false;
             }

-            // Check vector store requirement; assumes last dimension in N to be contiguous
-            if(device_arg.c_stride_lowest_ != 1)
+            // Check vector load/store requirement
+            const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
+                                             ? device_arg.a_mz_kz_strides_[1]
+                                             : device_arg.a_mz_kz_strides_[0];
+            const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
+                                             ? device_arg.b_nz_kz_strides_[1]
+                                             : device_arg.b_nz_kz_strides_[0];
+            const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
+                                              ? device_arg.b1_nz_kz_strides_[1]
+                                              : device_arg.b1_nz_kz_strides_[0];
+            const auto c_stride_lowest =
+                device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
+                                                     // contiguous
+
+            if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
+                 c_stride_lowest == 1))
             {
                 return false;
             }
@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                              std::vector<const void*> p_b_vec,
                              std::vector<const void*> p_b1_vec,
                              std::vector<void*> p_c_vec,
+                             std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                             std::vector<std::vector<const void*>> p_acc1_biases_vec,
                              std::vector<ProblemDesc> problem_desc_vec,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         p_b_vec,
                         p_b1_vec,
                         p_c_vec,
+                        p_acc0_biases_vec,
+                        p_acc1_biases_vec,
                         problem_desc_vec,
                         a_element_op,
                         b_element_op,
@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec,
-                                                      std::vector<const void*> p_b_vec,
-                                                      std::vector<const void*> p_b1_vec,
-                                                      std::vector<void*> p_c_vec,
-                                                      std::vector<ProblemDesc> problem_desc_vec,
-                                                      AElementwiseOperation a_element_op,
-                                                      BElementwiseOperation b_element_op,
-                                                      AccElementwiseOperation acc_element_op,
-                                                      B1ElementwiseOperation b1_element_op,
-                                                      CElementwiseOperation c_element_op) override
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::vector<const void*> p_a_vec,
+                        std::vector<const void*> p_b_vec,
+                        std::vector<const void*> p_b1_vec,
+                        std::vector<void*> p_c_vec,
+                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
+                        std::vector<ProblemDesc> problem_desc_vec,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        AccElementwiseOperation acc_element_op,
+                        B1ElementwiseOperation b1_element_op,
+                        CElementwiseOperation c_element_op) override
     {
         return std::make_unique<Argument>(p_a_vec,
                                           p_b_vec,
                                           p_b1_vec,
                                           p_c_vec,
+                                          p_acc0_biases_vec,
+                                          p_acc1_biases_vec,
                                           problem_desc_vec,
                                           a_element_op,
                                           b_element_op,
@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             << Gemm1NPerBlock << ", "
             << Gemm1KPerBlock << ", "
             << B1K1 << ", "
-            << getGemmSpecializationString(GemmSpec) << ">";
+            << getGemmSpecializationString(GemmSpec) << ", "
+            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
+            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
+            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
+            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
+            << getMaskingSpecializationString(MaskingSpec) << ">";
         // clang-format on

         return str.str();
...
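
The masking rework in this file moves the hard-coded upper-triangle test into a predicate type selected at compile time from MaskingSpec. The dispatch pattern is ordinary C++17; a self-contained sketch (illustrative types, not CK's C0MatrixMask_impl):

```cpp
#include <iostream>

enum class Masking { Disabled, UpperTriangle };

struct NoMask
{
    bool is_masked(int, int) const { return false; }
};

struct UpperTriangleMask
{
    // Mask out strictly-upper-triangle elements (n > m), as causal attention does.
    bool is_masked(int m, int n) const { return n > m; }
};

// The if constexpr branches return different types; decltype picks the right one.
template <Masking Spec>
constexpr auto make_predicate()
{
    if constexpr(Spec == Masking::Disabled)
        return NoMask{};
    else
        return UpperTriangleMask{};
}

template <Masking Spec>
using MaskFor = decltype(make_predicate<Spec>());

int main()
{
    MaskFor<Masking::UpperTriangle> mask = make_predicate<Masking::UpperTriangle>();
    std::cout << mask.is_masked(1, 2) << ' ' << mask.is_masked(2, 1) << '\n'; // 1 0
    return 0;
}
```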
@@ -11,33 +11,6 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {

-struct DeviceNormalization : public BaseOperator
-{
-    // inLengths: input tensor extent(s) from high to low dimension
-    // inStrides: input tensor stride(s) from high to low dimension
-    // reduceDims: the dimension(s) the normalization operation is applied
-    // alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
-    // beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
-    // in_dev: typeless const pointer in device memory storing the input tensor
-    // out_dev: typeless pointer in device memory storing the output tensor
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
-                                                              const std::vector<index_t> inStrides,
-                                                              const std::vector<int> reduceDims,
-                                                              const void* alpha,
-                                                              const void* beta,
-                                                              const void* in_dev,
-                                                              void* out_dev) = 0;
-
-    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
-
-    virtual index_t GetRank() const = 0;
-
-    virtual index_t GetNumReduceDim() const = 0;
-};
-
-using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
-
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
@@ -46,7 +19,7 @@ template <typename XDataType,
           typename AccElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
-struct DeviceLayernorm : public BaseOperator
+struct DeviceNormalization : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const std::vector<index_t> lengths,
@@ -73,14 +46,14 @@ template <typename XDataType,
           typename AccElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
-using DeviceLayernormPtr = std::unique_ptr<DeviceLayernorm<XDataType,
-                                                           GammaDataType,
-                                                           BetaDataType,
-                                                           AccDataType,
-                                                           YDataType,
-                                                           AccElementwiseOperation,
-                                                           Rank,
-                                                           NumReduceDim>>;
+using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
+                                                                   GammaDataType,
+                                                                   BetaDataType,
+                                                                   AccDataType,
+                                                                   YDataType,
+                                                                   AccElementwiseOperation,
+                                                                   Rank,
+                                                                   NumReduceDim>>;

 } // namespace device
 } // namespace tensor_operation
...
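
This hunk folds the untemplated DeviceNormalization base (with its virtual GetRank()/GetNumReduceDim()) into the former DeviceLayernorm shape: rank and reduce-dim count become template parameters of the abstract interface, so mismatches surface at compile time instead of through runtime queries. A reduced sketch of the pattern (hypothetical names, standard C++ only):

```cpp
#include <memory>
#include <vector>

// Before: one type-erased base for all ranks; rank is a runtime property.
struct OpAnyRank
{
    virtual ~OpAnyRank()        = default;
    virtual int GetRank() const = 0; // caller must check at runtime
};

// After: rank is part of the interface type, so a rank-2 pointer cannot
// hold a rank-4 implementation, and no virtual rank query is needed.
template <int Rank, int NumReduceDim>
struct OpFixedRank
{
    virtual ~OpFixedRank()                            = default;
    virtual void Run(const std::vector<int>& lengths) = 0;
};

template <int Rank, int NumReduceDim>
using OpFixedRankPtr = std::unique_ptr<OpFixedRank<Rank, NumReduceDim>>;

struct LayernormRank2 final : OpFixedRank<2, 1>
{
    void Run(const std::vector<int>&) override {}
};

int main()
{
    OpFixedRankPtr<2, 1> op = std::make_unique<LayernormRank2>();
    op->Run({16, 1024});
    return 0;
}
```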
@@ -3,27 +3,30 @@

 #pragma once

-#include <vector>
+#include <array>
 #include <memory>
-#include <iostream>

-#include "ck/utility/common_header.hpp"
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_enums.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename InElementwiseOperation, typename AccElementwiseOperation>
+template <index_t Rank,
+          index_t NumReduceDim,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation>
 struct DeviceReduce : public BaseOperator
 {
+    static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
+
     virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<index_t> inLengths,
-                        const std::vector<index_t> inStrides,
-                        const std::vector<index_t> outLengths,
-                        const std::vector<index_t> outStrides,
-                        const std::vector<int> reduceDims,
+    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
+                        const std::array<index_t, Rank> inStrides,
+                        const std::array<index_t, NumOutDim> outLengths,
+                        const std::array<index_t, NumOutDim> outStrides,
+                        const std::array<int, NumReduceDim> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
@@ -36,9 +39,12 @@ struct DeviceReduce : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename InElementwiseOperation, typename AccElementwiseOperation>
-using DeviceReducePtr =
-    std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>;
+template <index_t Rank,
+          index_t NumReduceDim,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation>
+using DeviceReducePtr = std::unique_ptr<
+    DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
...
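
Moving DeviceReduce's MakeArgumentPointer from std::vector to std::array ties the argument sizes to Rank and NumReduceDim, so passing the wrong number of lengths or reduce dimensions fails to compile rather than at runtime. A minimal sketch of that benefit (illustrative function, not the CK signature):

```cpp
#include <array>
#include <iostream>

// NumOutDim mirrors the DeviceReduce convention: a full reduction still has
// one (degenerate) output dimension.
template <int Rank, int NumReduceDim>
constexpr int NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;

template <int Rank, int NumReduceDim>
void make_reduce_args(const std::array<int, Rank>& in_lengths,
                      const std::array<int, NumOutDim<Rank, NumReduceDim>>& out_lengths,
                      const std::array<int, NumReduceDim>& reduce_dims)
{
    std::cout << in_lengths.size() << " -> " << out_lengths.size() << " (reducing "
              << reduce_dims.size() << " dims)\n";
}

int main()
{
    // OK: rank 4, reduce 2 dims, 2 output dims.
    make_reduce_args<4, 2>({2, 3, 4, 5}, {2, 3}, {2, 3});

    // A size mismatch, e.g. make_reduce_args<4, 2>({2, 3, 4, 5}, {2}, {2, 3}),
    // is rejected by the compiler instead of throwing at runtime.
    return 0;
}
```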