Unverified commit 370efa6c authored by ltqin, committed by GitHub

batched_gemm + multiple_d + gemm + multiple_d (#394)



* refactor

* start

* add device gemm file

* add BatchStrideD0

* add StrideD0

* add gridwise file

* add d0 parameters to gridwise gemm

* add c layout transformer

* add d0 threadwise copy

* init kernel

* init kernel

* regular code

* nm desc put to out

* kernel parameters cannot be passed by reference

* host add bias+gelu

* run correctly for bias+gelu

* change AddFastGelu into another file

* interface add d1 bias parameters

* add d1 parameter to argument

* add d1 parameter to gridwise

* first complete version of the code, not yet verified

* change gelu to relu and fix GetElementSpaceSize bug

* add instance

* start add to ckprofiler

* finish ckProfiler code

* change input parameter for ckProfiler

* fix host bias+gelu bug

* show help for ckProfiler

* fix bug where launching the kernel ignored parameters

* add padding and fix related bug

* multiple d0

* add dynamic d0_element_op

* change profiler and instance to multiple d0

* example has 2 d0 tensors

* remove some unused comments

* give the 2 d0 tensors their own parameters

* change d element_op name

* change class name(multiple_d)

* fix bug

* fix bug where a file could not be found

* update profiler

* refactor

* update profiler

* clean

* revert example change

* add gon layout

* optimize parameter for gno

* add gon to gemm+gemm

* update help text for input parameters

* change to GemmPadder_v2

* using ForEach

* fix gb_per_sec
Co-authored-by: Chao Liu <lc.roy86@gmail.com>
Co-authored-by: ltqin <letaoqin@amd.com>
parent b22ebd44
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Computes E1[m, o] = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[m, n]) * B1[n, o] + D1[m, o]
*/
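// With BatchCount = G, the example evaluates, per batch g:
//   Acc0[m, n] = A0[m, k] * B0[n, k]              (Gemm0; A0 row-major, B0 column-major)
//   E0[m, n]   = CDE0ElementOp(Acc0, D00, D01)    (AddAddRelu below)
//   Acc1[m, o] = E0[m, n] * B1[n, o]              (Gemm1; B1 row-major)
//   E1[m, o]   = CDE1ElementOp(Acc1, D1)          (element_wise::Add)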
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using A0DataType = F16;
using B0DataType = F16;
using Acc0DataType = F32;
using D00DataType = F16;
using D01DataType = F16;
using B1DataType = F16;
using Acc1DataType = F32;
using C1ShuffleDataType = F32;
using D1DataType = F16;
using E1DataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D00Layout = Row;
using D01Layout = Row;
using B1Layout = Row;
using D1Layout = Row;
using E1Layout = Row;
// E = Relu(C + D0 + D1)
struct AddAddRelu
{
__host__ __device__ void
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const ck::half_t x = c + d0 + d1;
ck::tensor_operation::element_wise::Relu{}.template operator()<ck::half_t>(e, x);
}
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + (d0 + d1);
ck::tensor_operation::element_wise::Relu{}.template operator()<float>(e, x);
}
};
// E = Gelu(C + D0 + D1)
struct AddAddGelu
{
__host__ __device__ void
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const ck::half_t x = c + d0 + d1;
ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
x);
}
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + (d0 + d1);
ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + (d0 + d1);
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
}
};
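// Only AddAddRelu is wired up as CDE0ElementOp below; AddAddGelu and AddAddFastGelu are
// alternative CDE0 functors with the same (e, c, d0, d1) calling pattern.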
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;
using A1ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
static constexpr bool PadGemm0M = false;
static constexpr bool PadGemm0N = false;
static constexpr bool PadGemm0K = false;
static constexpr bool PadGemm1N = false;
static constexpr bool PadGemm1K = false;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
A0Layout,
B0Layout,
ck::Tuple<D00Layout, D01Layout>,
B1Layout,
ck::Tuple<D1Layout>,
E1Layout,
A0DataType,
B0DataType,
Acc0DataType,
ck::Tuple<D00DataType, D01DataType>,
B1DataType,
Acc1DataType,
C1ShuffleDataType,
ck::Tuple<D1DataType>,
E1DataType,
A0ElementOp,
B0ElementOp,
CDE0ElementOp,
B1ElementOp,
CDE1ElementOp,
PadGemm0M,
PadGemm0N,
PadGemm0K,
PadGemm1N,
PadGemm1K,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
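// The instance above uses a 256-thread block that works on a 128x128 Gemm0 tile in K slices
// of 32 and a 128-wide Gemm1 tile (Gemm1NPerBlock = 128, Gemm1KPerBlock = 32) on 32x32 XDL ops.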
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA0 = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideD00 = -1;
ck::index_t StrideD01 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideD1 = -1;
ck::index_t StrideE1 = -1;
ck::index_t BatchStrideA0 = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideD00 = -1;
ck::index_t BatchStrideD01 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideD1 = -1;
ck::index_t BatchStrideE1 = -1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 23)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA0 = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideD00 = std::stoi(argv[11]);
StrideD01 = std::stoi(argv[12]);
StrideB1 = std::stoi(argv[13]);
StrideD1 = std::stoi(argv[14]);
StrideE1 = std::stoi(argv[15]);
BatchStrideA0 = std::stoi(argv[16]);
BatchStrideB0 = std::stoi(argv[17]);
BatchStrideD00 = std::stoi(argv[18]);
BatchStrideD01 = std::stoi(argv[19]);
BatchStrideB1 = std::stoi(argv[20]);
BatchStrideD1 = std::stoi(argv[21]);
BatchStrideE1 = std::stoi(argv[22]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 8: M, N, K, O, Batch\n");
printf(
"arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
"BatchStrideB1, BatchStrideD1, BatchStrideE1 \n");
exit(0);
}
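    // Example invocations (binary name taken from the CMake entry above; shapes are illustrative):
    //   ./example_batched_gemm_add_add_relu_gemm_add_xdl_fp16           # defaults: M=N=1024, K=64, O=128, Batch=4
    //   ./example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 1 1 1     # verify, integer init, time kernel
    //   ./example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 1 2 1 512 512 64 128 8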
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
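    // A default batch stride is one matrix's footprint: slow-dimension length x leading stride
    // (e.g. M * StrideA0 for row-major A0, K * StrideA0 for column-major A0).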
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// E1_m_o = Relu(A0_m_k * B0_k_n + D00_m_n + D01_m_n) * B1_n_o + D1_m_o
Tensor<A0DataType> a0_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<D00DataType> d00_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
Tensor<D01DataType> d01_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<D1DataType> d1_g_m_o(
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
Tensor<E1DataType> e1_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
Tensor<E1DataType> e1_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
<< " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
<< " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
break;
case 2:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
break;
default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
}
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
e1_g_m_o_device_result.mDesc.GetElementSize());
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
auto a0_element_op = A0ElementOp{};
auto b0_element_op = B0ElementOp{};
auto cde0_element_op = CDE0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto cde1_element_op = CDE1ElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
d01_g_m_n_device_buf.GetDeviceBuffer()},
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
std::array<ck::index_t, 2>{StrideD00, StrideD01},
StrideB1,
std::array<ck::index_t, 1>{StrideD1},
StrideE1,
BatchStrideA0,
BatchStrideB0,
std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
BatchStrideB1,
std::array<ck::index_t, 1>{BatchStrideD1},
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
sizeof(D1DataType) * O) *
BatchCount;
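    // ave_time is in ms, so flop / 1e9 / ms yields TFLOP/s and bytes / 1e6 / ms yields GB/s.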
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
if(do_verification)
{
using ReferenceGemm0Instance =
ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B0DataType,
Acc0DataType,
Acc0DataType,
A0ElementOp,
B0ElementOp,
PassThrough>;
using ReferenceGemm1Instance =
ck::tensor_operation::host::ReferenceBatchedGemm<Acc0DataType,
B1DataType,
Acc1DataType,
Acc1DataType,
PassThrough,
B1ElementOp,
PassThrough>;
// Output of Gemm0 is input A of Gemm1
Tensor<Acc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<Acc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<Acc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias+bias+relu
e0_g_m_n.ForEach([&](auto&, auto idx) {
cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
});
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// bias
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
});
return ck::utils::check_err(e1_g_m_o_device_result.mData, e1_g_m_o_host_result.mData) ? 0
: 1;
}
return 0;
}
@@ -52,4 +52,5 @@ add_subdirectory(33_multiple_reduce)
add_subdirectory(34_batchnorm)
add_subdirectory(35_splitK_gemm)
add_subdirectory(36_sparse_embedding)
add_subdirectory(37_batched_gemm_add_add_relu_gemm_add)
add_subdirectory(41_grouped_conv_conv_fwd)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation>
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
{
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
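// D0s are the extra M x N tensors combined with the Gemm0 output via CDE0ElementwiseOperation;
// D1s are the extra M x O tensors combined with the Gemm1 output via CDE1ElementwiseOperation.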
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a0,
const void* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
const void* p_b1,
std::array<const void*, NumD1Tensor> p_d1s,
void* p_e1,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA0,
ck::index_t StrideB0,
std::array<ck::index_t, NumD0Tensor> StrideD0s,
ck::index_t StrideB1,
std::array<ck::index_t, NumD1Tensor> StrideD1s,
ck::index_t StrideE1,
ck::index_t BatchStrideA0,
ck::index_t BatchStrideB0,
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
ck::index_t BatchStrideB1,
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
ck::index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename A0B0B1DataType,
typename D0sPointer,
typename D1sPointer,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
typename A0GridDesc_AK0_M_AK1,
typename B0GridDesc_BK0_N_BK1,
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1,
typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2E1TileMap,
typename ComputeBasePtrOfStridedBatch,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_xdl_cshuffle_v1(
const A0B0B1DataType* __restrict__ p_a0_grid,
const A0B0B1DataType* __restrict__ p_b0_grid,
D0sPointer p_d0s_grid,
const A0B0B1DataType* __restrict__ p_b1_grid,
D1sPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
const A0ElementwiseOperation a0_element_op,
const B0ElementwiseOperation b0_element_op,
const CDE0ElementwiseOperation cde0_element_op,
const B1ElementwiseOperation b1_element_op,
const CDE1ElementwiseOperation cde1_element_op,
const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2E1TileMap block_2_e1tile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
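    // readfirstlane marks the per-batch index and offsets as wavefront-uniform so they can
    // live in scalar registers.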
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
});
static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) {
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset;
});
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a0_grid + a_batch_offset,
p_b0_grid + b_batch_offset,
p_d0s_grid,
p_b1_grid + b1_batch_offset,
p_d1s_grid,
p_e1_grid + c_batch_offset,
p_shared,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op,
a0_grid_desc_ak0_m_ak1,
b0_grid_desc_bk0_n_bk1,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
e1_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_e1tile_map);
#else
ignore = p_a0_grid;
ignore = p_b0_grid;
ignore = p_d0s_grid;
ignore = p_b1_grid;
ignore = p_d1s_grid;
ignore = p_e1_grid;
ignore = a0_element_op;
ignore = b0_element_op;
ignore = cde0_element_op;
ignore = b1_element_op;
ignore = cde1_element_op;
ignore = a0_grid_desc_ak0_m_ak1;
ignore = b0_grid_desc_bk0_n_bk1;
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_e1tile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
#endif
}
// Computes E1 = CDE1ElementOp(Acc1, D1s...)
//   where Acc0 = A0 * B0, E0 = CDE0ElementOp(Acc0, D0s...), Acc1 = E0 * B1
template <typename A0Layout,
typename B0Layout, // B0Layout
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename Acc0DataType,
typename D0sDataType,
typename B1DataType,
typename Acc1DataType,
typename C1ShuffleDataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
bool PadGemm0M,
bool PadGemm0N,
bool PadGemm0K,
bool PadGemm1N,
bool PadGemm1K,
index_t NumGemm0KPrefetchStage,
index_t BlockSize,
index_t Gemm0MPerBlock,
index_t Gemm0NPerBlock,
index_t Gemm0KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t A0K1,
index_t B0K1,
index_t B1K1,
index_t Gemm0MPerXdl,
index_t Gemm0NPerXdl,
index_t Gemm0MXdlPerWave,
index_t Gemm0NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
typename A0BlockTransferThreadClusterArrangeOrder,
typename A0BlockTransferSrcAccessOrder,
index_t A0BlockTransferSrcVectorDim,
index_t A0BlockTransferSrcScalarPerVector,
index_t A0BlockTransferDstScalarPerVector_AK1,
bool A0BlockLdsExtraM,
typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
index_t B0BlockTransferSrcVectorDim,
index_t B0BlockTransferSrcScalarPerVector,
index_t B0BlockTransferDstScalarPerVector_BK1,
bool B0BlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
index_t C1ShuffleMXdlPerWavePerShuffle,
index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
: public DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto gemm0_padder =
GemmPadder_v2<PadGemm0M, PadGemm0N, PadGemm0K, index_t, index_t, index_t>{
Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock};
static constexpr auto gemm1_padder =
GemmPadder_v2<PadGemm0M, PadGemm1N, PadGemm1K, index_t, index_t, index_t>{
Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock};
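    // For gemm1_padder, "M" is the Gemm0 M dimension, "N" is the Gemm1 N (= O) dimension, and
    // "K" is the Gemm0 N dimension, tiled by Gemm0MPerBlock / Gemm1NPerBlock / Gemm1KPerBlock.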
// for Gemm0
static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0)
{
const auto a0_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, A0Layout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA0, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, A0Layout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA0));
}
}();
return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw);
}
// for Gemm0
static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b0_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B0Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B0Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw);
}
// for Gemm0
template <typename DLay>
static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0)
{
const auto d0_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideD0, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideD0));
}
}();
return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw);
}
// for Gemm1
static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw);
}
// for Gemm1
template <typename ELay>
static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1)
{
const auto e1_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE1, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE1));
}
}();
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
}
static auto MakeD0sGridDescriptor_M_N(const std::array<index_t, NumD0Tensor>& MRaws,
const std::array<index_t, NumD0Tensor>& NRaws,
const std::array<index_t, NumD0Tensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
return DeviceOp::MakeD0GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumD0Tensor>{});
}
static auto MakeD1sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
const std::array<index_t, NumD1Tensor>& NRaws,
const std::array<index_t, NumD1Tensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
return DeviceOp::MakeE1GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumD1Tensor>{});
}
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1)
: BatchStrideA0_(BatchStrideA0),
BatchStrideB0_(BatchStrideB0),
BatchStrideD0s_(BatchStrideD0s),
BatchStrideB1_(BatchStrideB1),
BatchStrideD1s_(BatchStrideD1s),
BatchStrideE1_(BatchStrideE1)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA0_);
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
}
template <index_t I>
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE1_);
}
template <index_t I>
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
}
private:
index_t BatchStrideA0_;
index_t BatchStrideB0_;
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
index_t BatchStrideB1_;
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
index_t BatchStrideE1_;
};
using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(1, 1, 1));
using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(1, 1, 1));
using D0sGridDesc_M_N = remove_cvref_t<decltype(MakeD0sGridDescriptor_M_N({}, {}, {}))>;
using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(1, 1, 1));
using D1sGridDesc_M_N = remove_cvref_t<decltype(MakeD1sGridDescriptor_M_N({}, {}, {}))>;
using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N<E1Layout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
A0DataType, // TODO: distinguish A/B datatype
Acc0DataType,
D0sDataType,
Acc1DataType,
C1ShuffleDataType,
D1sDataType,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
A0GridDesc_M_K,
B0GridDesc_N_K,
D0sGridDesc_M_N,
B1GridDesc_N_K,
D1sGridDesc_M_N,
E1GridDesc_M_N,
NumGemm0KPrefetchStage,
BlockSize,
Gemm0MPerBlock,
Gemm0NPerBlock,
Gemm0KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
A0K1,
B0K1,
B1K1,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm0NXdlPerWave,
Gemm1NXdlPerWave,
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
A0BlockTransferThreadClusterArrangeOrder,
A0BlockTransferSrcAccessOrder,
A0BlockTransferSrcVectorDim,
A0BlockTransferSrcScalarPerVector,
A0BlockTransferDstScalarPerVector_AK1,
true,
A0BlockLdsExtraM,
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
B0BlockTransferThreadClusterArrangeOrder,
B0BlockTransferSrcAccessOrder,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_BK1,
true,
B0BlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
C1ShuffleMXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>;
using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>;
using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(const A0DataType* p_a0_grid,
const B0DataType* p_b0_grid,
std::array<const void*, NumD0Tensor> p_d0s_grid,
const B1DataType* p_b1_grid,
std::array<const void*, NumD1Tensor> p_d1s_grid,
E1DataType* p_e1_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw, // = ORaw
index_t Batch,
index_t StrideA0,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op)
: p_a0_grid_{p_a0_grid},
p_b0_grid_{p_b0_grid},
p_d0s_grid_{},
p_b1_grid_{p_b1_grid},
p_d1s_grid_{},
p_e1_grid_{p_e1_grid},
a0_grid_desc_m_k_{DeviceOp::MakeA0GridDescriptor_M_K(MRaw, KRaw, StrideA0)},
b0_grid_desc_n_k_{DeviceOp::MakeB0GridDescriptor_N_K(KRaw, NRaw, StrideB0)},
d0s_grid_desc_m_n_{},
b1_grid_desc_n_k_{DeviceOp::MakeB1GridDescriptor_N_K(NRaw, Gemm1NRaw, StrideB1)},
d1s_grid_desc_m_n_{},
e1_grid_desc_m_n_{
DeviceOp::MakeE1GridDescriptor_M_N<E1Layout>(MRaw, Gemm1NRaw, StrideE1)},
a0_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)},
b0_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)},
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
b1_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)},
d1s_grid_desc_mblock_mperblock_nblock_nperblock_{},
e1_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)},
a0_element_op_{a0_element_op},
b0_element_op_{b0_element_op},
cde0_element_op_{cde0_element_op},
b1_element_op_{b1_element_op},
cde1_element_op_{cde1_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA0,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1}
{
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", "
<< d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
<< std::endl;
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_d0s_grid[i]);
// D0 desc
d0s_grid_desc_m_n_(i) =
DeviceOp::MakeD0GridDescriptor_M_N<D0Layout>(MRaw, NRaw, StrideD0s[i]);
});
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
using D1Layout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
// D1 pointer
p_d1s_grid_(i) = static_cast<const D1DataType*>(p_d1s_grid[i]);
// D1 desc
d1s_grid_desc_m_n_(i) =
DeviceOp::MakeE1GridDescriptor_M_N<D1Layout>(MRaw, Gemm1NRaw, StrideD1s[i]);
});
if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_,
b0_grid_desc_n_k_,
b1_grid_desc_n_k_,
e1_grid_desc_m_n_,
block_2_e1tile_map_))
{
e1_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e1_grid_desc_m_n_);
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0s_grid_desc_m_n_);
d1s_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d1s_grid_desc_m_n_);
}
}
// private:
// pointers
const A0DataType* p_a0_grid_;
const B0DataType* p_b0_grid_;
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
const B1DataType* p_b1_grid_;
typename GridwiseGemm::D1sGridPointer p_d1s_grid_;
E1DataType* p_e1_grid_;
// tensor descriptors for problem definition
A0GridDesc_M_K a0_grid_desc_m_k_;
B0GridDesc_N_K b0_grid_desc_n_k_;
D0sGridDesc_M_N d0s_grid_desc_m_n_;
B1GridDesc_N_K b1_grid_desc_n_k_;
D1sGridDesc_M_N d1s_grid_desc_m_n_;
E1GridDesc_M_N e1_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_;
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e1-tile map
typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_;
// element-wise op
A0ElementwiseOperation a0_element_op_;
B0ElementwiseOperation b0_element_op_;
CDE0ElementwiseOperation cde0_element_op_;
B1ElementwiseOperation b1_element_op_;
CDE1ElementwiseOperation cde1_element_op_;
// batch
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_,
arg.e1_grid_desc_m_n_,
arg.block_2_e1tile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_;
// Gemm0_K
const auto K = arg.a0_grid_desc_m_k_.GetLength(I1);
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_gemm_xdl_cshuffle_v1<
GridwiseGemm,
A0DataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::D0sGridPointer,
typename GridwiseGemm::D1sGridPointer,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
DeviceOp::A0GridDesc_AK0_M_AK1,
DeviceOp::B0GridDesc_BK0_N_BK1,
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2E1TileMap,
ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a0_grid_,
arg.p_b0_grid_,
arg.p_d0s_grid_,
arg.p_b1_grid_,
arg.p_d1s_grid_,
arg.p_e1_grid_,
arg.a0_element_op_,
arg.b0_element_op_,
arg.cde0_element_op_,
arg.b1_element_op_,
arg.cde1_element_op_,
arg.a0_grid_desc_ak0_m_ak1_,
arg.b0_grid_desc_bk0_n_bk1_,
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_e1tile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to check whether Gemm0's K loop has a main block loop
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_,
arg.e1_grid_desc_m_n_,
arg.block_2_e1tile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const A0DataType* p_a0,
const B0DataType* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
const B1DataType* p_b1,
std::array<const void*, NumD1Tensor> p_d1s,
E1DataType* p_e1,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
index_t StrideA0,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op)
{
return Argument{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a0,
const void* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
const void* p_b1,
std::array<const void*, NumD1Tensor> p_d1s,
void* p_e1,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
index_t StrideA0,
index_t StrideB0,
std::array<ck::index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<ck::index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op) override
{
return std::make_unique<Argument>(static_cast<const A0DataType*>(p_a0),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_e1),
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
StrideA0,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< Gemm0MPerBlock << ", "
<< Gemm0NPerBlock << ", "
<< Gemm0KPerBlock << ", "
<< A0K1 << ", "
<< B0K1 << ", "
<< B1K1 << ", "
<< Gemm0MPerXdl << ", "
<< Gemm0NPerXdl << ", "
<< Gemm0MXdlPerWave << ", "
<< Gemm0NXdlPerWave << ", "
<< Gemm1NXdlPerWave << "> ";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -218,6 +218,165 @@ struct GemmPadder_v2
KPerTileType KPerTile_;
};
// M/N/KPerTileType could be index_t or Number<>
template <bool PadM,
bool PadN,
bool PadK,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
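        // e.g. (illustrative numbers) MRaw = 1000, MPerTile_ = 128 -> M = 1024, MPad = 24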
if constexpr(PadM && PadK)
{
// pad both M and K
return transform_tensor_descriptor(a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadM && (!PadK))
{
// pad M, but not K
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadM) && PadK)
{
// pad K, but not M
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or K
return a_desc_mraw_kraw;
}
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(PadN && PadK)
{
// pad both N and K
return transform_tensor_descriptor(b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadN && (!PadK))
{
// pad N, but not K
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadN) && PadK)
{
// pad K, but not N
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad N or K
return b_desc_nraw_kraw;
}
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(PadM && PadN)
{
// pad M and N
return transform_tensor_descriptor(c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadM && (!PadN))
{
// pad M, but not N
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadM) && PadN)
{
// pad N, but not M
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_desc_mraw_nraw;
}
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
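// Minimal usage sketch (assumed values, for illustration only): pad a 1000 x 60 row-major
// A matrix up to 128 x 32 tiles so every block tile is fully covered.
//
//   const auto padder = MatrixPadder_v2<true, true, true, index_t, index_t, index_t>{128, 128, 32};
//   const auto a_desc = make_naive_tensor_descriptor(make_tuple(1000, 60), make_tuple(60, 1));
//   const auto a_pad  = padder.PadADescriptor_M_K(a_desc); // lengths become 1024 x 64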
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -28,6 +28,13 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 + type_convert<half_t>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
@@ -172,6 +179,14 @@ struct AddRelu
const float a = x0 + x1;
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
};
template <>
__host__ __device__ constexpr void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > 0.0f ? a : 0.0f;
};
};
struct AddHardswish
@@ -210,6 +225,46 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
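    // Note: 2/(1+exp(-u)) - 1 == tanh(u/2), and u = 2*sqrt(2/pi)*(x + 0.044715*x^3) here, so
    // cdf equals 0.5*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))), i.e. the formula quoted above.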
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
e = type_convert<E>(y);
}
template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
{
static_assert(is_valid_param_type_v<D>);
e = GetFastGeLU(c + type_convert<float>(d));
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
@@ -211,6 +211,27 @@ struct FastGelu
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
const ck::half_t& x) const
{
y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same datatype
typename Acc0DataType,
typename D0sDataType,
typename Acc1DataType,
typename C1ShuffleDataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
InMemoryDataOperationEnum E1GlobalMemoryDataOperation,
typename A0GridDesc_M_K,
typename B0GridDesc_N_K,
typename D0sGridDesc_M_N,
typename B1GridDesc_N_K,
typename D1sGridDesc_M_N,
typename E1GridDesc_M_N,
index_t NumGemm0KPrefetchStage,
index_t BlockSize,
index_t Gemm0MPerBlock,
index_t Gemm0NPerBlock,
index_t Gemm0KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t A0K1Value,
index_t B0K1Value,
index_t B1K1Value,
index_t Gemm0MPerXdl,
index_t Gemm0NPerXdl,
index_t Gemm0MXdlPerWave,
index_t Gemm0NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
typename A0BlockTransferThreadClusterArrangeOrder,
typename A0BlockTransferSrcAccessOrder,
index_t A0BlockTransferSrcVectorDim,
index_t A0BlockTransferSrcScalarPerVector,
index_t A0BlockTransferDstScalarPerVector_AK1,
bool A0ThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t A0BlockLdsExtraM,
typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
index_t B0BlockTransferSrcVectorDim,
index_t B0BlockTransferSrcScalarPerVector,
index_t B0BlockTransferDstScalarPerVector_BK1,
bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t B0BlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t C1ShuffleGemm0MXdlPerWavePerShuffle,
index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto A0K1 = Number<A0K1Value>{};
static constexpr auto B0K1 = Number<B0K1Value>{};
static constexpr auto A0K0PerBlock = Number<Gemm0KPerBlock / A0K1Value>{};
static constexpr auto B0K0PerBlock = Number<Gemm0KPerBlock / B0K1Value>{};
static constexpr auto Gemm0MWaves = Gemm0MPerBlock / (Gemm0MPerXdl * Gemm0MXdlPerWave);
static constexpr auto Gemm0NWaves = Gemm0NPerBlock / (Gemm0NPerXdl * Gemm0NXdlPerWave);
// Gemm1
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto B1K0PerBlock = Number<Gemm1KPerBlock / B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static constexpr auto MakeD0sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return static_cast<const D0DataType*>(nullptr);
},
Number<NumD0Tensor>{});
}
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
static constexpr auto MakeD1sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
return static_cast<const D1DataType*>(nullptr);
},
Number<NumD1Tensor>{});
}
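// Note: the pointers above are never dereferenced; these makers exist only so that the
// tuple types D0sGridPointer/D1sGridPointer can be deduced via decltype further below.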
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
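// decompose the flat thread id into (MWave index, NWave index, lane within the wave)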
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
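// split a lane id within the wave into (lane / Gemm0NPerXdl, lane % Gemm0NPerXdl)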
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / Gemm0NPerXdl, Gemm0NPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0MXdlPerWave, MWaves, Gemm0MPerXdl>(
A0BlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = Gemm0NPerBlock / (Gemm0NXdlPerWave * Gemm0NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0NXdlPerWave, NWaves, Gemm0NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0MXdlPerWave, 1, 1>(
A0BlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, Gemm0NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static constexpr auto GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A0 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(A0K0PerBlock, Number<Gemm0MPerBlock>{}, A0K1),
make_tuple(Number<Gemm0MPerBlock + A0BlockLdsExtraM>{} * A0K1, A0K1, I1));
}
__host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B0 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B0K0PerBlock, Number<Gemm0NPerBlock>{}, B0K1),
make_tuple(Number<Gemm0NPerBlock + B0BlockLdsExtraN>{} * B0K1, B0K1, I1));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0PerBlock, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl>{},
I1,
Number<C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>{}));
return c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
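// LDS is reused across the three phases (gemm0 A0/B0, gemm1 B1, C1 shuffle),
// so the block needs the maximum of the three footprints rather than their sum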
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a0_block_space_size_aligned +
SharedMemTrait::b0_block_space_size_aligned) *
sizeof(A0B0B1DataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(A0B0B1DataType);
const index_t c1_block_bytes_end =
SharedMemTrait::c1_block_space_size * sizeof(C1ShuffleDataType);
return math::max(gemm0_bytes_end, gemm1_bytes_end, c1_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2E1TileMap>
__host__ __device__ static constexpr bool
CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k,
const B0GridDesc_N_K& b0_grid_desc_n_k,
const B1GridDesc_N_K& b1_grid_desc_n_k,
const E1GridDesc_M_N& e1_grid_desc_m_n,
const Block2E1TileMap& block_2_e1tile_map)
{
static_assert((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) == 0) &&
(Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a0_grid_desc_m_k.GetLength(I0);
const auto N = b0_grid_desc_n_k.GetLength(I0);
const auto K = a0_grid_desc_m_k.GetLength(I1);
const auto Gemm1N = b1_grid_desc_n_k.GetLength(I0);
if(!(M == e1_grid_desc_m_n.GetLength(I0) && Gemm1N == e1_grid_desc_m_n.GetLength(I1)))
{
return false;
}
if(!(M % Gemm0MPerBlock == 0 && N % Gemm0NPerBlock == 0 && K % Gemm0KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / Gemm0KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(Gemm0NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / Gemm0KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
// A0 desc for source in blockwise copy
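// splits K into (AK0 = K / A0K1, AK1) and reorders the descriptor to [AK0, M, AK1];
// e.g. K = 64 with A0K1 = 8 yields [8, M, 8]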
__host__ __device__ static constexpr auto
MakeDefaultA0GridDescriptor_AK0_M_AK1(const A0GridDesc_M_K& a0_grid_desc_m_k)
{
const auto M = a0_grid_desc_m_k.GetLength(I0);
const auto K = a0_grid_desc_m_k.GetLength(I1);
const auto A0K0 = K / A0K1;
return transform_tensor_descriptor(
a0_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A0K0, A0K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B0 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultB0GridDescriptor_BK0_N_BK1(const B0GridDesc_N_K& b0_grid_desc_n_k)
{
const auto N = b0_grid_desc_n_k.GetLength(I0);
const auto K = b0_grid_desc_n_k.GetLength(I1);
const auto B0K0 = K / B0K1;
return transform_tensor_descriptor(
b0_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B0K0, B0K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// D0 desc for source in blockwise copy
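// decomposes M into (MBlockId, MXdlPerWave repeat, MWave, MPerXdl) and N into
// (NBlockId, NXdlPerWave repeat, NWave, mfma group, WaveSize / NPerXdl, group size),
// so each thread can address the D0 elements that line up with its acc0 registers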
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto mfma =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(
M / Gemm0MPerBlock, Gemm0MXdlPerWave, Gemm0MWaves, Gemm0MPerXdl)),
make_unmerge_transform(make_tuple(N / Gemm0NPerBlock,
Gemm0NXdlPerWave,
Gemm0NWaves,
N3,
WaveSize / Gemm0NPerXdl,
N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
// B1 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k)
{
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E1 desc for destination in blockwise copy
__host__ __device__ static constexpr auto
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc_M_N& e1_grid_desc_m_n)
{
const auto M = e1_grid_desc_m_n.GetLength(I0);
const auto N = e1_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / Gemm0MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e1_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<Gemm0MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e1_grid_desc_mblock_mperblock_nblock_nperblock;
}
// D0s desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]);
},
Number<NumD0Tensor>{});
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumD1Tensor>{});
}
// return block_id to C1 matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2E1TileMap(const E1GridDesc_M_N& e1_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<Gemm0MPerBlock, Gemm1NPerBlock, E1GridDesc_M_N>(
e1_grid_desc_m_n);
}
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(E1GridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(D1sGridDesc_M_N{}))>;
using DefaultBlock2E1TileMap =
remove_cvref_t<decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
struct SharedMemTrait
{
// LDS allocation for A0 and B0: be careful of alignment
static constexpr auto a0_block_desc_ak0_m_ak1 =
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b0_block_desc_bk0_n_bk1 =
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto max_lds_align = math::lcm(math::lcm(A0K1, B0K1), B1K1);
static constexpr auto a0_block_space_size_aligned = math::integer_least_multiple(
a0_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
b0_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a0_block_space_offset = 0;
static constexpr auto b0_block_space_offset = a0_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for C1 shuffle in LDS
static constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c1_block_space_size =
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D1sGridPointer = decltype(MakeD1sGridPointer());
template <bool HasMainKBlockLoop,
typename A0GridDesc_AK0_M_AK1,
typename B0GridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename Block2E1TileMap>
__device__ static void Run(const A0B0B1DataType* __restrict__ p_a0_grid,
const A0B0B1DataType* __restrict__ p_b0_grid,
D0sGridPointer p_d0s_grid,
const A0B0B1DataType* __restrict__ p_b1_grid,
D1sGridPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
void* __restrict__ p_shared,
const A0ElementwiseOperation& a0_element_op,
const B0ElementwiseOperation& b0_element_op,
const CDE0ElementwiseOperation& cde0_element_op,
const B1ElementwiseOperation& b1_element_op,
const CDE1ElementwiseOperation& cde1_element_op,
const A0GridDesc_AK0_M_AK1& a0_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_N_BK1& b0_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2E1TileMap& block_2_e1tile_map)
{
const auto a0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a0_grid, a0_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto e1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
const auto d1s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1s_grid[i],
d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumD1Tensor>{});
// divide block work by [M, N]
const auto block_work_idx =
block_2_e1tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_e1tile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * Gemm0MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// A0 matrix in LDS memory, dst of blockwise copy
constexpr auto a0_block_desc_ak0_m_ak1 = GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B0 matrix in LDS memory, dst of blockwise copy
constexpr auto b0_block_desc_bk0_n_bk1 = GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
//
// set up Gemm0
//
// A0 matrix blockwise copy
auto a0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
A0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<A0K0PerBlock, Gemm0MPerBlock, A0K1>,
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
A0BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(a0_grid_desc_ak0_m_ak1),
decltype(a0_block_desc_ak0_m_ak1),
A0BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
A0BlockTransferSrcVectorDim,
2,
A0BlockTransferSrcScalarPerVector,
A0BlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemm0KPrefetchStage>(
a0_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a0_element_op,
a0_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B0 matrix blockwise copy
auto b0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B0K0PerBlock, Gemm0NPerBlock, B0K1>,
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
B0BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(b0_grid_desc_bk0_n_bk1),
decltype(b0_block_desc_bk0_n_bk1),
B0BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B0BlockTransferSrcVectorDim,
2,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemm0KPrefetchStage>(
b0_grid_desc_bk0_n_bk1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b0_element_op,
b0_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// Fused Gemm+Gemm pipeline
// for n in N0:
// for k in K0:
// acc[m][n] += A[m][k] * B0[k][n]
// acc1[m][o] += acc[m][n] * B1[n][o]
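// gemm1's K dimension is gemm0's N, so each outer iteration consumes one
// Gemm0NPerBlock-wide slab of acc in Gemm0NPerBlock / Gemm1KPerBlock inner steps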
// sanity check
constexpr index_t KPack = math::max(
math::lcm(A0K1, B0K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
BlockSize,
A0B0B1DataType,
Acc0DataType,
decltype(a0_block_desc_ak0_m_ak1),
decltype(b0_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a0_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b0_block_desc_bk0_n_bk1)),
Gemm0MPerBlock,
Gemm0NPerBlock,
Gemm0KPerBlock,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm0NXdlPerWave,
KPack,
true>{}; // TransposeC
auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();
// LDS allocation for A0 and B0: be careful of alignment
auto a0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::a0_block_space_offset,
a0_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
b0_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / A0K1, 0, 0);
constexpr auto b0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / B0K1, 0, 0);
const auto a0_block_reset_copy_step =
make_multi_index(-a0_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
const auto b0_block_reset_copy_step =
make_multi_index(-b0_grid_desc_bk0_n_bk1.GetLength(I0), Gemm0NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm0_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemm0KPrefetchStage, LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a0_grid_desc_ak0_m_ak1.GetLength(I0) * a0_grid_desc_ak0_m_ak1.GetLength(I2)) /
Gemm0KPerBlock);
//
// set up Gemm1
//
// Acc0 matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm0.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// d0 matrix threadwise copy
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
I1, // MRepeat
I1, // NRepeat
I1, // MWaveId
I1, // NWaveId
I1, // MPerXdl
I1, // NGroupNum
I1, // NInputNum
n4)); // registerNum
auto d0s_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<
AddressSpaceEnum::Vgpr,
A0B0B1DataType,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>{};
},
Number<NumD0Tensor>{});
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4));
auto d0s_threadwise_copy = generate_tuple(
[&](auto i) {
return ThreadwiseTensorSliceTransfer_v2<
A0B0B1DataType,
A0B0B1DataType,
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
n4,
1,
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
},
Number<NumD0Tensor>{});
// acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc0_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
constexpr auto acc0_thread_desc_k0_m_k1 = transform_tensor_descriptor(
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr auto Acc0N3 =
blockwise_gemm0.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
Number<Gemm1KPerBlock / n4 / Acc0N3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
Acc0DataType,
A0B0B1DataType,
decltype(acc0_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0PerBlock, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(b1_grid_desc_bk0_n_bk1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
1>(b1_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, A0B0B1DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// reuse LDS space for gemm0's b0_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = math::max(
math::lcm(
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
B1K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize,
A0B0B1DataType,
Acc1DataType,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
Gemm0MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{}
.K0PerXdlops>{ // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
const index_t num_gemm1_k_block_outer_loop =
b0_grid_desc_bk0_n_bk1.GetLength(I1) / Gemm0NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock;
// Initialize C1
c1_thread_buf.Clear();
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
{
// gemm0
gridwise_gemm0_pipeline.template Run<HasMainKBlockLoop>(a0_grid_desc_ak0_m_ak1,
a0_block_desc_ak0_m_ak1,
a0_blockwise_copy,
a0_grid_buf,
a0_block_buf,
a0_block_slice_copy_step,
b0_grid_desc_bk0_n_bk1,
b0_block_desc_bk0_n_bk1,
b0_blockwise_copy,
b0_grid_buf,
b0_block_buf,
b0_block_slice_copy_step,
blockwise_gemm0,
acc0_thread_buf,
num_k_block_main_loop);
// bias+gelu
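// apply cde0_element_op(acc0, d0s...) in place (e.g. add bias then ReLU), walking the
// D0 window over (mfma group, NRepeat, MRepeat) and finally stepping to the next N block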
{
static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) {
static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) {
static_for<0, n2, 1>{}([&](auto groupid) {
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).Run(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
d0s_grid_buf[i],
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0s_thread_buf(i));
});
static_for<0, n4, 1>{}([&](auto i) {
constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
make_tuple(mr, nr, groupid, i));
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
return d0s_thread_buf[iSrc][i];
},
Number<NumD0Tensor>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& {
return acc0_thread_buf(Number<c_offset>{});
},
Number<2>{});
unpack2(cde0_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
});
}
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for gemm0 LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc0_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf);
block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
});
}
// tail
{
a1_blockwise_copy.Run(
acc0_thread_desc_k0_m_k1,
make_tuple(
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf);
}
} // end gemm1
a0_blockwise_copy.MoveSrcSliceWindow(a0_grid_desc_ak0_m_ak1,
a0_block_reset_copy_step); // rewind K
b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_bk0_n_bk1,
b0_block_reset_copy_step); // rewind K and step N
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// shuffle C1 and write out
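// acc1 is staged through LDS in C1ShuffleDataType one shuffle slice at a time, then a
// blockwise copy applies cde1_element_op(c1, d1s...) and streams the E1DataType result
// to global memory, iterating both sides with matching space-filling curves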
{
static_assert(Gemm0MXdlPerWave % C1ShuffleGemm0MXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % C1ShuffleGemm0NXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
// TODO: hacky, fix it!
constexpr auto c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm1.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm1.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c1_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<C1ShuffleDataType*>(p_shared),
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<C1ShuffleGemm0MXdlPerWavePerShuffle>{}, // M0 (Gemm0MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = Gemm0MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<C1ShuffleGemm0NXdlPerWavePerShuffle>{}, // N0 (Gemm0NXdlPerWave) per
// shuffle
N1, // N1 = NWave
N2))), // N2 = Gemm0NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM C1 matrix starting index
const auto c1_thread_mtx_on_block =
blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c1_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c1_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c1_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
C1ShuffleDataType,
decltype(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
tensor_operation::element_wise::PassThrough,
Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c1_d1s_desc_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumD1Tensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c1_d1s_buf_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_buf[i]; },
Number<NumD1Tensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c1_d1s_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumD1Tensor>{}));
// shuffle: blockwise copy C from LDS to global
auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(C1ShuffleDataType{}), D1sDataType{})),
Tuple<E1DataType>,
decltype(c1_d1s_desc_refs),
decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
CDE1ElementwiseOperation,
Sequence<static_cast<index_t>(E1GlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary
// type
Sequence<1,
C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
1,
C1ShuffleGemm0NXdlPerWavePerShuffle * NWave *
Gemm0NPerXdl>, // BlockSliceLengths,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumD1Tensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c1_d1s_desc_refs,
idx_c1_d1s_block_begin,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde1_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c1_vgpr =
SpaceFillingCurve<Sequence<Gemm0MXdlPerWave, Gemm1NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_e1_global = SpaceFillingCurve<
Sequence<1, Gemm0MPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
1,
C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>>{};
constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
c1_thread_buf,
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde1_shuffle_block_copy_lds_to_global.Run(
c1_d1s_desc_refs,
c1_d1s_buf_refs,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e1_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);
// move on D1s
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
c1_d1s_desc_refs, i + I1, e1_global_step);
});
// move on C
cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
}
});
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>
{
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<A0DataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<E1DataType, half_t>)
{
if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<E1Layout, Row>)
{
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
else if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<E1Layout, Row>)
{
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
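// Usage sketch (illustrative only): instantiate the factory with a layout/type
// combination registered above and enumerate the kernels, e.g.
//
//   using Factory = DeviceOperationInstanceFactory<
//       DeviceBatchedGemmMultipleDGemmMultipleD<
//           Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row,
//           F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16,
//           PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp>>;
//
//   auto op_ptrs = Factory::GetInstances(); // std::vector<std::unique_ptr<DeviceOp>>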
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -16,6 +16,7 @@ add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_add_relu_gemm_add)
add_subdirectory(grouped_gemm)
add_subdirectory(contraction_scale)
add_subdirectory(contraction_bilinear)
@@ -42,6 +43,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_add_relu_gemm_add_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_contraction_scale_instance>
...
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
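// with AddRelu as CDE0 and Add as CDE1, these instances realize
// e1[g, m, o] = Relu(a0[g, m, k] * b0[g, n, k] + d0[g, m, n]) * b1[g, n, o] + d1[g, m, o]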
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
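// The instances below cover the full fused computation (matching the ckProfiler help text):
//   E1[g, m, o] = Relu(A0[g, m, k] * B0[g, n, k] + D0[g, m, n]) * B1[g, o, n] + D1[g, m, o]
// The Gemm0 result passes through the CDE0 epilogue (AddRelu) before feeding Gemm1, whose
// result passes through the CDE1 epilogue (Add). This "gon" list targets a column-major B1
// (i.e. B1[o, n]); the "gno" list above targets a row-major B1 (B1[n, o]).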
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -12,6 +12,8 @@ set(PROFILER_SOURCE
src/profile_gemm_add_add_fastgelu.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_batched_gemm_gemm.cpp
src/profile_batched_gemm_add_relu_gemm_add.cpp
src/profile_batched_gemm_reduce.cpp
src/profile_grouped_gemm.cpp
src/profile_conv_fwd.cpp
@@ -35,6 +37,8 @@ target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
namespace ck {
namespace profiler {
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType>
bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int BatchCount = 1,
int StrideA0 = -1,
int StrideB0 = -1,
int StrideD0 = -1,
int StrideB1 = -1,
int StrideD1 = -1,
int StrideE1 = -1,
int BatchStrideA0 = -1,
int BatchStrideB0 = -1,
int BatchStrideD0 = -1,
int BatchStrideB1 = -1,
int BatchStrideD1 = -1,
int BatchStrideE1 = -1)
{
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
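// Semantics of the two fused epilogues used below (per the element-wise ops selected above;
// operator() signatures are simplified here for illustration):
//   cde0: e0 = relu(c0 + d0), applied to the Gemm0 accumulator before it feeds Gemm1
//   cde1: e1 = c1 + d1,       applied to the Gemm1 accumulator
// A minimal host-side sketch of the same math, assuming float accumulators:
//   auto add_relu = [](float& e, float c, float d) { e = (c + d) > 0.f ? (c + d) : 0.f; };
//   auto add      = [](float& e, float c, float d) { e = c + d; };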
using D0DataType = remove_cvref_t<tuple_element_t<0, D0sDataType>>;
using D0Layout = remove_cvref_t<tuple_element_t<0, D0sLayout>>;
using D1DataType = remove_cvref_t<tuple_element_t<0, D1sDataType>>;
using D1Layout = remove_cvref_t<tuple_element_t<0, D1sLayout>>;
// for reference
using RefAcc0DataType = float;
using RefAcc1DataType = float;
bool pass = true;
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD0 = (ck::is_same_v<D0Layout, Col> ? N : M) * StrideD0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
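// Worked example of the defaulting above for the fp16 MK_NK_MN_NO_MO_MO path (A0 row-major,
// B0 column-major, everything else row-major) with the ckProfiler default sizes
// M = N = 1024, K = 64, O = 128:
//   StrideA0 = K = 64,   BatchStrideA0 = M * StrideA0 = 65536
//   StrideB0 = K = 64,   BatchStrideB0 = N * StrideB0 = 65536
//   StrideD0 = N = 1024, BatchStrideD0 = M * StrideD0 = 1048576
//   StrideB1 = O = 128,  BatchStrideB1 = N * StrideB1 = 131072
//   StrideD1 = StrideE1 = O = 128, BatchStrideD1 = BatchStrideE1 = M * StrideE1 = 131072
// i.e. every operand is packed and consecutive batches are laid out back to back.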
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// E_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<A0DataType> a0_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<D0DataType> d0_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<D1DataType> d1_g_m_o(
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
Tensor<E1DataType> e1_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
Tensor<E1DataType> e1_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
// Host verification: Output of Gemm0 is input A of Gemm1
Tensor<RefAcc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<RefAcc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 3});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
break;
default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
// Allocate by element space size so that user-specified (non-packed) strides are honored.
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
e1_g_m_o_device_result.mDesc.GetElementSpaceSize());
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
auto a0_element_op = A0ElementOp{};
auto b0_element_op = B0ElementOp{};
auto cde0_element_op = CDE0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto cde1_element_op = CDE1ElementOp{};
using DeviceOp =
tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
A0ElementOp,
B0ElementOp,
CDE0ElementOp,
B1ElementOp,
CDE1ElementOp>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
if(do_verification)
{
// Ref Gemm0
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B0DataType,
RefAcc0DataType,
RefAcc0DataType,
A0ElementOp,
B0ElementOp,
PassThrough>;
// Ref Gemm1
using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<RefAcc0DataType,
B1DataType,
RefAcc1DataType,
RefAcc1DataType,
PassThrough,
B1ElementOp,
PassThrough>;
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// cde0_elementwise
e0_g_m_n.ForEach(
[&](auto&, auto idx) { cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d0_g_m_n(idx)); });
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// cde1_elementwise
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
});
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
std::array<ck::index_t, 1>{StrideD0},
StrideB1,
std::array<ck::index_t, 1>{StrideD1},
StrideE1,
BatchStrideA0,
BatchStrideB0,
std::array<ck::index_t, 1>{BatchStrideD0},
BatchStrideB1,
std::array<ck::index_t, 1>{BatchStrideD1},
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D0DataType) * M * N +
sizeof(B1DataType) * N * O + sizeof(D1DataType) * M * O + sizeof(E1DataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
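// Example of the metrics above: with the ckProfiler defaults M = N = 1024, K = 64, O = 128,
// BatchCount = 4, flop = (2*M*N*K + 2*M*N*O) * BatchCount ~= 1.61e9, so an ave_time of 1 ms
// reports ~1.61 TFlops (flop / 1e9 / ave_time_ms). Likewise num_btype / 1.E6 / ave_time
// converts bytes moved per millisecond into GB/s.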
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(e1_g_m_o_device_result.mData,
e1_g_m_o_host_result.mData);
if(do_log)
{
LogRangeAsType<float>(
std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
int profile_batched_gemm_add_relu_gemm_add(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_NK_MN_NO_MO_MO, // 0
MK_NK_MN_ON_MO_MO, // 1
};
enum struct GemmDataType
{
F32_F32_F32_F32_F32_F32, // 0
F16_F16_F16_F16_F16_F16, // 1
};
GemmDataType data_type = GemmDataType::F16_F16_F16_F16_F16_F16;
GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_MN_NO_MO_MO;
bool do_verification = true;
int init_method = 1;
bool do_log = false;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA0 = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideD0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideD1 = -1;
ck::index_t StrideE1 = -1;
ck::index_t BatchStrideA0 = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideD0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideD1 = -1;
ck::index_t BatchStrideE1 = -1;
if(argc == 8)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
}
else if(argc == 13)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
}
else if(argc == 25)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
StrideA0 = std::stoi(argv[13]);
StrideB0 = std::stoi(argv[14]);
StrideD0 = std::stoi(argv[15]);
StrideB1 = std::stoi(argv[16]);
StrideD1 = std::stoi(argv[17]);
StrideE1 = std::stoi(argv[18]);
BatchStrideA0 = std::stoi(argv[19]);
BatchStrideB0 = std::stoi(argv[20]);
BatchStrideD0 = std::stoi(argv[21]);
BatchStrideB1 = std::stoi(argv[22]);
BatchStrideD1 = std::stoi(argv[23]);
BatchStrideE1 = std::stoi(argv[24]);
}
else
{
printf("arg1: tensor operation (batched_gemm_add_relu_gemm_add: "
"Batched_GEMM+Add+Relu+Gemm+Add)\n");
printf("arg2: data type (1: fp16)\n");
printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = "
"E1[m, o];)\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 12: M, N, K, O, Batch\n");
printf("arg13 to 18: StrideA0, StrideB0, StrideD0, StrideB1, StrideD1, StrideE1\n");
printf("arg19 to 24: BatchStrideA0, BatchStrideB0, BatchStrideD0, BatchStrideB1, "
"BatchStrideD1, BatchStrideE1 \n");
exit(1);
}
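// Example invocation matching the argument list above (illustrative; the path to the
// ckProfiler binary depends on the local build tree):
//   ckProfiler batched_gemm_add_relu_gemm_add 1 0 1 1 0 1 1024 1024 64 128 4
// i.e. fp16 data, layout 0, verification on, integer init, no logging, kernel timing on,
// M = N = 1024, K = 64, O = 128, BatchCount = 4, with all strides left at their defaults.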
if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 &&
layout == GemmMatrixLayout::MK_NK_MN_NO_MO_MO)
{
ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl<Row, // A0Layout,
Col, // B0Layout,
ck::Tuple<Row>, // D0sLayout,
Row, // B1Layout,
ck::Tuple<Row>, // D1sLayout,
Row, // E1Layout,
F16, // A0DataType,
F16, // B0DataType,
ck::Tuple<F16>, // D0DataType,
F16, // B1DataType,
ck::Tuple<F16>, // D1sDataType
F16> // E1DataType,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideD0,
StrideB1,
StrideD1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideD0,
BatchStrideB1,
BatchStrideD1,
BatchStrideE1);
}
else if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 &&
layout == GemmMatrixLayout::MK_NK_MN_ON_MO_MO)
{
ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl<Row, // A0Layout,
Col, // B0Layout,
ck::Tuple<Row>, // D0sLayout,
Col, // B1Layout,
ck::Tuple<Row>, // D1sLayout,
Row, // E1Layout,
F16, // A0DataType,
F16, // B0DataType,
ck::Tuple<F16>, // D0DataType,
F16, // B1DataType,
ck::Tuple<F16>, // D1sDataType
F16> // E1DataType,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideD0,
StrideB1,
StrideD1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideD0,
BatchStrideB1,
BatchStrideD1,
BatchStrideE1);
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
int profile_batched_gemm_gemm(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_NK_NO_MO, // 0
MK_NK_ON_MO, // 1
};
enum struct GemmDataType
{
F32_F32_F32_F32, // 0
F16_F16_F16_F16, // 1
};
GemmDataType data_type = GemmDataType::F16_F16_F16_F16;
GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_NO_MO;
bool do_verification = true;
int init_method = 1;
bool do_log = false;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA0 = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideE1 = -1;
ck::index_t BatchStrideA0 = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideE1 = -1;
if(argc == 8)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
}
else if(argc == 13)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
}
else if(argc == 21)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
StrideA0 = std::stoi(argv[13]);
StrideB0 = std::stoi(argv[14]);
StrideB1 = std::stoi(argv[15]);
StrideE1 = std::stoi(argv[16]);
BatchStrideA0 = std::stoi(argv[17]);
BatchStrideB0 = std::stoi(argv[18]);
BatchStrideB1 = std::stoi(argv[19]);
BatchStrideE1 = std::stoi(argv[20]);
}
else
{
printf("arg1: tensor operation (batched_gemm_gemm: Batched_GEMM+Gemm)\n");
printf("arg2: data type (1: fp16)\n");
printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = E1[m, "
"o];)\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 12: M, N, K, O, Batch\n");
printf("arg13 to 16: StrideA0, StrideB0, StrideB1, StrideE1\n");
printf("arg17 to 20: BatchStrideA0, BatchStrideB0, BatchStrideB1, BatchStrideE1 \n");
exit(1);
}
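// Example invocation matching the argument list above (illustrative; the path to the
// ckProfiler binary depends on the local build tree):
//   ckProfiler batched_gemm_gemm 1 0 1 1 0 1 1024 1024 64 128 4
// i.e. fp16 data, layout 0 (B1[n, o]), verification on, integer init, no logging, kernel
// timing on, M = N = 1024, K = 64, O = 128, BatchCount = 4, strides defaulted.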
if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_NO_MO)
{
ck::profiler::profile_batched_gemm_gemm_impl<F16, // A0DataType,
F16, // B0DataType,
F16, // B1DataType,
F16, // E1DataType,
Row, // A0Layout,
Col, // B0Layout,
Row, // B1Layout,
Row> // E1Layout,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideB1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideB1,
BatchStrideE1);
}
else if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_ON_MO)
{
ck::profiler::profile_batched_gemm_gemm_impl<F16, // A0DataType,
F16, // B0DataType,
F16, // B1DataType,
F16, // E1DataType,
Row, // A0Layout,
Col, // B0Layout,
Col, // B1Layout,
Row> // E1Layout,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideB1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideB1,
BatchStrideE1);
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
return 0;
}
@@ -10,6 +10,8 @@ int profile_gemm_add_add_fastgelu(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_gemm_bias_add_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_batched_gemm_gemm(int, char*[]);
int profile_batched_gemm_add_relu_gemm_add(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]);
@@ -32,6 +34,8 @@ static void print_helper_message()
" gemm_reduce: GEMM+Reduce\n"
" gemm_bias_add_reduce: GEMM+Bias+Add+Reduce\n"
" batched_gemm: Batched GEMM\n"
" batched_gemm_gemm: Batched+GEMM+GEMM\n"
" batched_gemm_add_relu_gemm_add: Batched+GEMM+bias+gelu+GEMM+bias\n"
" batched_gemm_reduce: Batched GEMM+Reduce\n" " batched_gemm_reduce: Batched GEMM+Reduce\n"
" grouped_gemm: Grouped GEMM\n" " grouped_gemm: Grouped GEMM\n"
" conv_fwd: Convolution Forward\n" " conv_fwd: Convolution Forward\n"
...@@ -80,6 +84,14 @@ int main(int argc, char* argv[]) ...@@ -80,6 +84,14 @@ int main(int argc, char* argv[])
{ {
return profile_batched_gemm(argc, argv); return profile_batched_gemm(argc, argv);
} }
else if(strcmp(argv[1], "batched_gemm_gemm") == 0)
{
return profile_batched_gemm_gemm(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_add_relu_gemm_add") == 0)
{
return profile_batched_gemm_add_relu_gemm_add(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_reduce") == 0) else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{ {
return profile_batched_gemm_reduce(argc, argv); return profile_batched_gemm_reduce(argc, argv);
......