Commit bd0f0686 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents e9b1000f 63914743
......@@ -156,16 +156,16 @@ int main(int argc, char* argv[])
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -175,11 +175,11 @@ int main(int argc, char* argv[])
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
d1_m_n_device_buf.GetDeviceBuffer()},
e_m_n_device_buf.GetDeviceBuffer(),
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
......@@ -239,7 +239,7 @@ int main(int argc, char* argv[])
}
}
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
}
......
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp)
......@@ -166,15 +166,15 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n)
{
AccDataType acc =
static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));
AccDataType acc = ck::type_convert<AccDataType>(c_m_n(m, n)) +
ck::type_convert<AccDataType>(bias_n(n));
AccDataType c1 = static_cast<AccDataType>(c1_m_n(m, n));
AccDataType c1 = ck::type_convert<AccDataType>(c1_m_n(m, n));
c_element_op(acc, acc);
c1_element_op(c1, c1);
acc += c1;
c_m_n(m, n) = static_cast<CDataType>(acc);
c_m_n(m, n) = ck::type_convert<CDataType>(acc);
}
// reduce_mean and reduce_square_mean
......@@ -208,12 +208,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{
AccDataType out_acc = 0;
layerNormInst(out_acc,
static_cast<AccDataType>(c_m_n(m, n)),
static_cast<AccDataType>(mean_m(m)),
static_cast<AccDataType>(meanSquare_m(m)),
static_cast<AccDataType>(gamma_n(n)),
static_cast<AccDataType>(beta_n(n)));
out_m_n(m, n) = static_cast<ReduceDataType>(out_acc);
ck::type_convert<AccDataType>(c_m_n(m, n)),
ck::type_convert<AccDataType>(mean_m(m)),
ck::type_convert<AccDataType>(meanSquare_m(m)),
ck::type_convert<AccDataType>(gamma_n(n)),
ck::type_convert<AccDataType>(beta_n(n)));
out_m_n(m, n) = ck::type_convert<ReduceDataType>(out_acc);
}
}
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// This example demonstrate a single kernel that runs GEMM layer and laynorm in one fused kernel
//
// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers
// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example,
// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128
//
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using C0DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
struct Relu
{
template <typename OutT, typename InT>
__host__ __device__ void operator()(OutT& y, const InT& x) const
{
y = x > 0 ? x : 0;
}
};
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
// Elementwise operation that operates on the output of matrix multiplication
// i.e., AccElementOp(A * B + bias)
using AccElementOp = Relu;
// Elementwise operation that operates on the output of layer normalization
using CElementOp = Relu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadCopy|
//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4>;
// clang-format on
using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType,
BDataType,
CDataType,
C0DataType,
AccDataType,
AElementOp,
BElementOp,
AccElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 128;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 128;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<AccDataType> acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<C0DataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
Tensor<C0DataType> c0_m_n_add(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<C0DataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
Tensor<C0DataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "c0_n_bias: " << c0_n_bias.mDesc << std::endl;
std::cout << "c0_m_n_add: " << c0_m_n_add.mDesc << std::endl;
std::cout << "c0_n_gamma: " << c0_n_gamma.mDesc << std::endl;
std::cout << "c0_n_beta: " << c0_n_beta.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
c0_m_n_add.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 2});
c0_n_beta.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 5});
c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace());
DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpace());
DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace());
DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c0_bias_buf.ToDevice(c0_n_bias.mData.data());
c0_add_buf.ToDevice(c0_m_n_add.mData.data());
c0_gamma_buf.ToDevice(c0_n_gamma.mData.data());
c0_beta_buf.ToDevice(c0_n_beta.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto acc_element_op = AccElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_add_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_bias_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_gamma_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_beta_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// extra 6MN flops due to: bias + add + gamma + beta + norm_sub + norm_div,
// excluding reduction steps
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N;
// extra MN and 3N due to c0_add (MxN), bias (1xN), gamma (1xN), beta (1xN)
std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = bytes / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
auto ref_gemm = ReferenceInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
b_k_n,
c_m_n_host_result,
c0_n_bias,
c0_m_n_add,
c0_n_gamma,
c0_n_beta,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
if constexpr(std::is_same<CShuffleDataType, F32>::value)
{
pass &= ck::utils::check_err(
c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c");
}
else if constexpr(std::is_same<CShuffleDataType, F16>::value)
{
pass &= ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results c",
1e-2,
1e-2);
}
}
return pass ? 0 : 1;
}
......@@ -150,6 +150,9 @@ int main(int argc, char* argv[])
AccDataType alpha = args.scales[0];
AccDataType beta = args.scales[1];
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "out: " << out.mDesc << std::endl;
std::size_t num_thread = 1;
if(args.do_verification)
......@@ -195,7 +198,7 @@ int main(int argc, char* argv[])
using ReferenceInstance =
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
ReferenceInstance ref;
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, Rank, reduceDims);
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims);
auto invoker = ref.MakeInvoker();
invoker.Run(ref_arg);
// LogRangeAsType<float>(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl;
......@@ -212,8 +215,8 @@ int main(int argc, char* argv[])
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
i_inStrides,
reduceDims,
alpha,
beta,
&alpha,
&beta,
in_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer());
......
add_example_executable(example_batched_gemm_c_permute_xdl_fp16 batched_gemm_c_permute_xdl_fp16.cpp)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
< Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
const int M = 88;
const int N = 64;
const int K = 88;
const int stride_A = K;
const int stride_B = K;
const int G0 = 1024;
const int G1 = 10;
const int batch_count = G0 * G1;
// output layout - [G0, M, G1, N]
const int stride_G0 = M * G1 * N;
const int stride_G1 = N;
const int stride_M = G1 * N;
const int stride_N = 1;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
// GEMM shape
ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{
G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({row * stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({col * stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
auto f_host_c_tensor_descriptor = [](std::size_t G0_,
std::size_t G1_,
std::size_t M_,
std::size_t N_,
std::size_t stride_G0_,
std::size_t stride_G1_,
std::size_t stride_M_,
std::size_t stride_N_) {
return HostTensorDescriptor(
std::vector<std::size_t>({G0_, G1_, M_, N_}),
std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
};
Tensor<CDataType> c_g0_g1_m_n_host_result(
f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
Tensor<CDataType> c_g0_g1_m_n_device_result(
f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "c_g0_g1_m_n: " << c_g0_g1_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
a_element_op,
b_element_op,
c_element_op,
batch_count);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(CDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
c_device_buf.FromDevice(c_g0_g1_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<CDataType> c_g_m_n_host_result = HostTensorDescriptor(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
for(int g0 = 0; g0 < G0; g0++)
{
for(int g1 = 0; g1 < G1; g1++)
{
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
int g = g0 * G1 + g1;
c_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
}
}
}
}
pass = ck::utils::check_err(c_g0_g1_m_n_host_result.mData,
c_g0_g1_m_n_device_result.mData,
"Error: Incorrect results c");
}
return pass ? 0 : 1;
}
add_example_executable(example_gemm_bias_c_permute_xdl_fp16 gemm_bias_c_permute_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t M0 = 4;
ck::index_t M1 = 32;
ck::index_t M2 = 128;
ck::index_t N0 = 16;
ck::index_t N1 = 256;
// GEMM shape
ck::index_t M = M0 * M1 * M2;
ck::index_t N = N0 * N1;
ck::index_t K = 128;
ck::index_t stride_A = K;
ck::index_t stride_B = K;
#if 1
// E = [M0, N0, M1, N1, M2]
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
#else
// D = [0, 0, 0, N0, N1]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
// E = [M0, M1, M2, N0, N1]
ck::index_t stride_E_M0 = M1 * M2 * N0 * N1;
ck::index_t stride_E_M1 = M2 * N0 * N1;
ck::index_t stride_E_M2 = N0 * N1;
ck::index_t stride_E_N0 = N1;
ck::index_t stride_E_N1 = 1;
#endif
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{
M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1};
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{
M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
auto f_host_de_tensor_descriptor =
[](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) {
std::size_t m0 = de_grid_desc.M0_;
std::size_t m1 = de_grid_desc.M1_;
std::size_t m2 = de_grid_desc.M2_;
std::size_t n0 = de_grid_desc.N0_;
std::size_t n1 = de_grid_desc.N1_;
std::size_t stride_m0 = de_grid_desc.stride_M0_;
std::size_t stride_m1 = de_grid_desc.stride_M1_;
std::size_t stride_m2 = de_grid_desc.stride_M2_;
std::size_t stride_n0 = de_grid_desc.stride_N0_;
std::size_t stride_n1 = de_grid_desc.stride_N1_;
return HostTensorDescriptor(
std::vector<std::size_t>({m0, m1, m2, n0, n1}),
std::vector<std::size_t>({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1}));
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
Tensor<DDataType> d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl;
std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) *
d_m0_m1_m2_n0_n1.mDesc.GetElementSpace());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) *
e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< device_op.GetTypeString() << std::endl;
if(do_verification)
{
Tensor<AccDataType> c_m_n(HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m0 = 0; m0 < M0; ++m0)
for(int m1 = 0; m1 < M1; ++m1)
for(int m2 = 0; m2 < M2; ++m2)
for(int n0 = 0; n0 < N0; ++n0)
for(int n1 = 0; n1 < N1; ++n1)
{
int m = m0 * M1 * M2 + m1 * M2 + m2;
int n = n0 * N1 + n1;
cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1),
ck::type_convert<EDataType>(c_m_n(m, n)),
d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1));
}
e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data());
return ck::utils::check_err(e_m0_m1_m2_n0_n1_device_result.mData,
e_m0_m1_m2_n0_n1_host_result.mData)
? 0
: 1;
}
return 0;
}
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
# Instructions for ```example_contraction_bilinear_xdl_fp32```
## Run
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
```
Result (MI100 @ dynammic freq, 46TFlops peak FP32)
```
a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096}
c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceKNNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMKNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMNNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float alpha = 1.f;
float beta = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 28)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t M0 = std::stoi(argv[4]);
const ck::index_t M1 = std::stoi(argv[5]);
const ck::index_t N0 = std::stoi(argv[6]);
const ck::index_t N1 = std::stoi(argv[7]);
const ck::index_t K0 = std::stoi(argv[8]);
const ck::index_t K1 = std::stoi(argv[9]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};
d_ms_ns_lengths = {M0, M1, N0, N1};
d_ms_ns_strides = {
std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])};
alpha = std::stof(argv[26]);
beta = std::stof(argv[27]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n");
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg26 to 27: alpha, beta\n");
exit(0);
}
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
Tensor<BDataType> b_ns_ks(
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<EDataType> d_ms_ns(
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
d_device_buf.ToDevice(d_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceOpInstanceKKN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceKNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMKN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float scale = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 23)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t M0 = std::stoi(argv[4]);
const ck::index_t M1 = std::stoi(argv[5]);
const ck::index_t N0 = std::stoi(argv[6]);
const ck::index_t N1 = std::stoi(argv[7]);
const ck::index_t K0 = std::stoi(argv[8]);
const ck::index_t K1 = std::stoi(argv[9]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])};
scale = std::stof(argv[26]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n");
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg22: scale\n");
exit(0);
}
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
Tensor<BDataType> b_ns_ks(
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{scale};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
}
return 0;
}
......@@ -22,7 +22,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
add_subdirectory(01_gemm)
add_subdirectory(02_gemm_alpha_beta)
add_subdirectory(02_gemm_bilinear)
add_subdirectory(03_gemm_bias_relu)
add_subdirectory(04_gemm_add_add_fastgelu)
add_subdirectory(06_conv2d_fwd_bias_relu)
......@@ -42,3 +42,6 @@ add_subdirectory(20_convnd_bwd_weight_xdl)
add_subdirectory(21_gemm_layernorm)
add_subdirectory(22_cgemm)
add_subdirectory(23_softmax)
add_subdirectory(24_batched_gemm_c_permute)
add_subdirectory(25_gemm_bias_c_permute)
add_subdirectory(26_contraction)
......@@ -102,7 +102,12 @@
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
// experimental feature: buffer load/store/atomic-add/ OOB trick
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting. Don't use this hack unless absolutely necessary!
// FIXME: make the behavior of buffer load a configurable (template) parameter for each usage
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
......@@ -167,16 +172,6 @@ struct InMemoryDataOperationEnumSequence
}
};
#if 0
// TODO: no longer needed, remove this
enum struct ActivTypeEnum
{
None,
LeakyRelu,
Sigmoid
};
#endif
// index type
using index_t = int32_t;
using long_index_t = int64_t;
......
......@@ -18,7 +18,7 @@ enum struct ConvolutionForwardSpecialization
OddC,
};
inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s)
inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
{
switch(s)
{
......
......@@ -12,33 +12,56 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t Batch) = 0;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
ck::index_t Batch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmPtr = std::unique_ptr<
DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
using DeviceBatchedGemmPtr = std::unique_ptr<DeviceBatchedGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct BatchedGemmCPermuteDesc
{
ck::index_t G0_, G1_, M_, N_;
ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmCPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t BatchCount) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
ck::Tuple<>{},
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
0>{},
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
template <typename ALayout,
typename BLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto
MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{
const auto c_grid_desc_mraw_nraw = [&]() {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(stride_M, stride_N));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
index_t G1,
index_t MRaw,
index_t NRaw,
index_t stride_G0,
index_t stride_G1,
index_t stride_M,
index_t stride_N)
{
const auto e_grid_desc_g0_g1_mraw_nraw = [&]() {
return make_naive_tensor_descriptor(
make_tuple(G0, G1, MRaw, NRaw),
make_tuple(stride_G0, stride_G1, stride_M, stride_N));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(
e_grid_desc_g0_g1_mraw_nraw,
make_tuple(make_pass_through_transform(G0),
make_pass_through_transform(G1),
make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
e_grid_desc_g0_g1_mraw_nraw,
make_tuple(make_pass_through_transform(G0),
make_pass_through_transform(G1),
make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
e_grid_desc_g0_g1_mraw_nraw,
make_tuple(make_pass_through_transform(G0),
make_pass_through_transform(G1),
make_pass_through_transform(MRaw),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
else
{
// not pad M or N
return e_grid_desc_g0_g1_mraw_nraw;
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
index_t Batchstride_B,
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
: Batchstride_A_(Batchstride_A),
Batchstride_B_(Batchstride_B),
e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(Batchstride_A_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(Batchstride_B_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
index_t b0 = g_idx / G1;
index_t b1 = g_idx - b0 * G1; // g_idx % G1
return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
}
private:
index_t Batchstride_A_;
index_t Batchstride_B_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
};
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType, // CShuffleDataType,
ck::Tuple<>, // DsDataType,
CDataType, // EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
NumPrefetch,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
BatchCount_(BatchCount),
a_grid_desc_k0_m_k1_{
DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
b_grid_desc_k0_n_k1_{
DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)},
c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N(
batched_gemm_c_permute_desc.G0_,
batched_gemm_c_permute_desc.G1_,
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_G0_,
batched_gemm_c_permute_desc.stride_G1_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
e_grid_desc_g0_g1_m_n_},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t BatchCount_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceBatchedGemmCPermuteXdl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
"setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_c_permute_xdl<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
a_element_op,
b_element_op,
c_element_op,
BatchCount};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
a_element_op,
b_element_op,
c_element_op,
BatchCount);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmCPermuteXdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -113,7 +113,7 @@ __global__ void
ignore = c_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif
}
template <typename ADataType,
......@@ -151,8 +151,15 @@ template <typename ADataType,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceBatchedGemmXdl
: public DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -334,12 +341,15 @@ struct DeviceBatchedGemmXdl
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t Batch)
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -350,10 +360,7 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
......@@ -536,10 +543,13 @@ struct DeviceBatchedGemmXdl
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t Batch)
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
......@@ -550,12 +560,15 @@ struct DeviceBatchedGemmXdl
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1,
1,
a_element_op,
b_element_op,
c_element_op,
Batch};
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -570,10 +583,13 @@ struct DeviceBatchedGemmXdl
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t Batch) override
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -584,12 +600,15 @@ struct DeviceBatchedGemmXdl
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1,
1,
a_element_op,
b_element_op,
c_element_op,
Batch);
c_element_op);
}
// polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment