Commit aea62819 authored by Chaitanya Inumella

Rebase branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into contraction_hipTENSOR

parents 75af5450 75ab874e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
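// Use a pseudo-random number of GEMM problems (1 to 16) so one grouped launch covers several sizes.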
int group_count = rand() % 16 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<std::array<const void*, 1>> p_ds;
std::vector<void*> p_c;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
int stride_A = K;
int stride_B = K;
int stride_C = N;
std::vector<ck::index_t> stride_Ds = {0};
gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, stride_Ds});
}
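// Host tensor descriptor for a 2-D matrix: row-major strides are {stride, 1}, column-major are {1, stride}.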
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<DDataType>> d_tensors;
std::vector<Tensor<EDataType>> e_host_tensors;
std::vector<Tensor<EDataType>> e_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
e_host_tensors.reserve(group_count);
e_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d_tensors_device,
e_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d_tensors_device.reserve(group_count);
e_tensors_device.reserve(group_count);
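// Accumulate total FLOP (2 * M * N * K per GEMM) and bytes moved across all groups for the perf report below.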
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
d_tensors.push_back(Tensor<DDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_Ds_[0], ELayout{})));
e_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
e_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << e_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize();
switch(init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_tensors[i].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_tensors[i].GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
d_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
}
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_tensors[i].mDesc.GetElementSpaceSize()));
e_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
d_tensors_device[i]->ToDevice(d_tensors[i].mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_ds.push_back({d_tensors_device[i]->GetDeviceBuffer()});
p_c.push_back(e_tensors_device[i]->GetDeviceBuffer());
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, cde_element_op);
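// Grouped GEMM keeps the per-group problem descriptors in a device workspace, which must be
// allocated and attached to the argument before launching.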
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
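// ave_time is in ms: flop / 1e9 / ms gives TFLOP/s, and bytes / 1e6 / ms gives GB/s.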
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
e_host_tensors[i],
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
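// The reference GEMM ran with PassThrough as its output op; apply the fused CDE op (Add of D) on the host.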
for(int m = 0; m < gemm_descs[i].M_; ++m)
{
for(int n = 0; n < gemm_descs[i].N_; ++n)
{
cde_element_op(
e_host_tensors[i](m, n), e_host_tensors[i](m, n), d_tensors[i](m, n));
}
}
pass &= ck::utils::check_err(e_device_tensors[i].mData, e_host_tensors[i].mData);
}
}
return pass ? 0 : 1;
}
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_bias_xdl_fp16 batched_gemm_bias_xdl_fp16.cpp)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
const int M = 256 * (rand() % 16 + 1);
const int N = 128 * (rand() % 16 + 1);
const int K = 64 * (rand() % 16 + 1);
const int stride_A = K;
const int stride_B = K;
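// stride_D = 0 turns D into a length-N bias row broadcast over all M rows of each batch.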
const int stride_D = 0;
const int stride_E = N;
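// A, B, and E batches are packed back to back; D stores one bias row (N elements) per batch.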
const int batch_stride_A = M * K;
const int batch_stride_B = K * N;
const int batch_stride_D = N;
const int batch_stride_E = M * N;
const int batch_count = 16;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
// GEMM shape
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
Tensor<DDataType> d_g_m_n(
f_host_tensor_descriptor(batch_count, M, N, stride_D, batch_stride_D, DLayout{}));
Tensor<EDataType> e_g_m_n_device_result(
f_host_tensor_descriptor(batch_count, M, N, stride_E, batch_stride_E, ELayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "d_g_m_n: " << d_g_m_n.mDesc << std::endl;
std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_g_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_g_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
d_device_buf.ToDevice(d_g_m_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{d_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(),
M,
N,
K,
batch_count,
stride_A,
stride_B,
{stride_D},
stride_E,
batch_stride_A,
batch_stride_B,
{batch_stride_D},
batch_stride_E,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(EDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());
using ReferenceBatchedGemmInstance =
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
EDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<EDataType> e_g_m_n_host_result(
f_host_tensor_descriptor(batch_count, M, N, stride_E, batch_stride_E, ELayout{}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int g = 0; g < batch_count; g++)
{
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_g_m_n_host_result(g, m, n),
e_g_m_n_host_result(g, m, n),
d_g_m_n(g, m, n));
}
}
}
pass = ck::utils::check_err(
e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c");
}
return pass ? 0 : 1;
}
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
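// Ds is an empty tuple: this variant fuses no auxiliary D tensors, and with CDEElementOp = PassThrough below, E = A * B.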
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, EDataType, AElementOp, BElementOp, CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
const int M = 256 * (rand() % 16 + 1);
const int N = 128 * (rand() % 16 + 1);
const int K = 64 * (rand() % 16 + 1);
const int stride_A = K;
const int stride_B = K;
const int stride_C = N;
const int batch_stride_A = M * K;
const int batch_stride_B = K * N;
const int batch_stride_C = M * N;
const int batch_count = 16;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
// GEMM shape
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
Tensor<EDataType> e_g_m_n_device_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
c_device_buf.GetDeviceBuffer(),
M,
N,
K,
batch_count,
stride_A,
stride_B,
{},
stride_C,
batch_stride_A,
batch_stride_B,
{},
batch_stride_C,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(EDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<EDataType> e_g_m_n_host_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
pass = ck::utils::check_err(
e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c");
}
return pass ? 0 : 1;
}
add_example_executable(example_grouped_convnd_fwd_bias_relu_xdl_fp16 grouped_convnd_fwd_bias_relu_xdl_fp16.cpp)
target_link_libraries(example_grouped_convnd_fwd_bias_relu_xdl_fp16 PRIVATE utility)
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
#Following arguments (depending on number of spatial dims):
# N spatial dimensions
# G, N, K, C,
# <filter spatial dimensions>, (ie Y, X for 2D)
# <input image spatial dimensions>, (ie Hi, Wi for 2D)
# <strides>, (ie Sy, Sx for 2D)
# <dilations>, (ie Dy, Dx for 2D)
# <left padding>, (ie LeftPy, LeftPx for 2D)
# <right padding>, (ie RightPy, RightPx for 2D)
bin/example_grouped_convnd_fwd_bias_relu_xdl_fp16 1 1 1
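# The example's built-in default problem, spelled out explicitly with the argument
# order described above (2 spatial dims; values mirror the source's default ConvParam):
bin/example_grouped_convnd_fwd_bias_relu_xdl_fp16 1 1 1 2 2 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1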
```
Result (MI100)
```
in: dim 5, lengths {1, 128, 192, 71, 71}, strides {6912, 967872, 1, 13632, 192}
wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {192, 1728, 1, 576, 192}
bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0}
out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256}
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 1.19215 ms, 123.112 TFlops, 279.827 GB/s, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<256, 128, 256, 32, Default>
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
int run_grouped_conv_fwd_bias(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& bias_g_n_k_wos_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> bias(bias_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "bias: " << bias.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
bias.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
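// Grouped conv tensors have NDimSpatial + 3 dimensions: G, N, and C (or K), plus the spatial lengths.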
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(bias_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths);
copy(bias_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(
in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_lengths}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_strides}},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(do_verification)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
Tensor<OutDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
PassThrough>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
c_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
// TODO: implement elementwise operation for host
out_host.ForEach(
[&](auto&, auto idx) { out_element_op(out_host(idx), c_host(idx), bias(idx)); });
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
? 0
: 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "grouped_convnd_fwd_bias_common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using BiasDataType = ck::half_t;
using OutDataType = ck::half_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename BiasLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<BiasDataType>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8>;
int main(int argc, char* argv[])
{
namespace ctc = ck::tensor_layout::convolution;
print_helper_msg();
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// conventional group conv definition
// G = 2
// [N, C, Hi, Wi] = [128, 384, 71, 71]
// [K, C, Y, X] = [512, 192, 3, 3]
// [N, K, Ho, Wo] = [128, 512, 36, 36]
// CK group conv definition
// [G, N, C, Hi, Wi] = [2, 128, 192, 71, 71]
// [G, K, C, Y, X] = [2, 256, 192, 3, 3]
// [G, N, K, Ho, Wo] = [2, 128, 256, 36, 36]
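// ConvParam{num_dim_spatial, G, N, K, C, filter lengths, input lengths, strides, dilations, left pads, right pads}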
ck::utils::conv::ConvParam conv_param{
2, 2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
if(conv_param.num_dim_spatial_ == 1)
{
using InLayout = ctc::G_NW_C;
using WeiLayout = ctc::G_K_X_C;
using BiasLayout = ctc::G_NW_K;
using OutLayout = ctc::G_NW_K;
const auto in_g_n_c_wis_desc = HostTensorDescriptor(
{conv_param.G_, conv_param.N_, conv_param.C_, conv_param.input_spatial_lengths_[0]},
{
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // n
1, // c
conv_param.G_ * conv_param.C_ // wi
});
const auto wei_g_k_c_xs_desc = HostTensorDescriptor(
{conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]},
{
conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.C_, // k
1, // c
conv_param.C_ // x
});
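// The bias descriptor uses zero strides along N and the spatial dimension, broadcasting one
// K-length bias vector across the whole batch and all output positions.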
const auto bias_g_n_k_wos_desc = HostTensorDescriptor(
{conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
{
conv_param.K_, // g
0, // n
1, // k
0 // wo
});
const auto out_g_n_k_wos_desc = HostTensorDescriptor(
{conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
{
conv_param.K_, // g
conv_param.output_spatial_lengths_[0] * conv_param.G_ * conv_param.K_, // n
1, // k
conv_param.G_ * conv_param.K_ // wo
});
return run_grouped_conv_fwd_bias<
1,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<1, InLayout, WeiLayout, BiasLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
bias_g_n_k_wos_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
}
else if(conv_param.num_dim_spatial_ == 2)
{
using InLayout = ctc::G_NHW_C;
using WeiLayout = ctc::G_K_YX_C;
using BiasLayout = ctc::G_NHW_K;
using OutLayout = ctc::G_NHW_K;
const auto in_g_n_c_wis_desc = HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.C_,
conv_param.input_spatial_lengths_[0],
conv_param.input_spatial_lengths_[1]},
{
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
conv_param.G_ * conv_param.C_, // n
1, // c
conv_param.input_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // hi
conv_param.G_ * conv_param.C_ // wi
});
const auto wei_g_k_c_xs_desc =
HostTensorDescriptor({conv_param.G_,
conv_param.K_,
conv_param.C_,
conv_param.filter_spatial_lengths_[0],
conv_param.filter_spatial_lengths_[1]},
{
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y
conv_param.C_ // x
});
const auto bias_g_n_k_wos_desc =
HostTensorDescriptor({conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1]},
{
conv_param.K_, // g
0, // n
1, // k
0, // ho
0 // wo
});
const auto out_g_n_k_wos_desc = HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1]},
{
conv_param.K_, // g
conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
conv_param.G_ * conv_param.K_, // n
1, // k
conv_param.output_spatial_lengths_[1] * conv_param.G_ * conv_param.K_, // ho
conv_param.G_ * conv_param.K_ // wo
});
return run_grouped_conv_fwd_bias<
2,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<2, InLayout, WeiLayout, BiasLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
bias_g_n_k_wos_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
}
else if(conv_param.num_dim_spatial_ == 3)
{
using InLayout = ctc::G_NDHW_C;
using WeiLayout = ctc::G_K_ZYX_C;
using BiasLayout = ctc::G_NDHW_K;
using OutLayout = ctc::G_NDHW_K;
const auto in_g_n_c_wis_desc = HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.C_,
conv_param.input_spatial_lengths_[0],
conv_param.input_spatial_lengths_[1],
conv_param.input_spatial_lengths_[2]},
{
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n
1, // c
conv_param.input_spatial_lengths_[1] * conv_param.input_spatial_lengths_[2] *
conv_param.G_ * conv_param.C_, // di
conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // hi
conv_param.G_ * conv_param.C_ // wi
});
const auto wei_g_k_c_xs_desc = HostTensorDescriptor(
{conv_param.G_,
conv_param.K_,
conv_param.C_,
conv_param.filter_spatial_lengths_[0],
conv_param.filter_spatial_lengths_[1],
conv_param.filter_spatial_lengths_[2]},
{
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
conv_param.C_, // z
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y
conv_param.C_ // x
});
const auto bias_g_n_k_wos_desc =
HostTensorDescriptor({conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1],
conv_param.output_spatial_lengths_[2]},
{
conv_param.K_, // g
0, // n
1, // k
0, // do
0, // ho
0 // wo
});
const auto out_g_n_k_wos_desc = HostTensorDescriptor(
{conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1],
conv_param.output_spatial_lengths_[2]},
{
conv_param.K_, // g
conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // n
1, // k
conv_param.output_spatial_lengths_[1] * conv_param.output_spatial_lengths_[2] *
conv_param.G_ * conv_param.K_, // do
conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // ho
conv_param.G_ * conv_param.K_ // wo
});
return run_grouped_conv_fwd_bias<
3,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<3, InLayout, WeiLayout, BiasLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
bias_g_n_k_wos_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
}
return 0;
}
@@ -8,7 +8,7 @@ add_custom_target(examples)
 function(add_example_executable EXAMPLE_NAME FILE_NAME)
 message("adding example ${EXAMPLE_NAME}")
 add_executable(${EXAMPLE_NAME} ${FILE_NAME})
-target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
+target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
 add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
 add_dependencies(examples ${EXAMPLE_NAME})
 add_dependencies(check ${EXAMPLE_NAME})
@@ -17,7 +17,7 @@ endfunction(add_example_executable EXAMPLE_NAME)
 function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
 message("adding example ${EXAMPLE_NAME}")
 add_executable(${EXAMPLE_NAME} ${FILE_NAME})
-target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
+target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
 add_dependencies(examples ${EXAMPLE_NAME})
 endfunction(add_example_executable_no_testing EXAMPLE_NAME)
@@ -25,23 +25,23 @@ add_subdirectory(01_gemm)
 add_subdirectory(02_gemm_bilinear)
 add_subdirectory(03_gemm_bias_relu)
 add_subdirectory(04_gemm_add_add_fastgelu)
-add_subdirectory(06_conv2d_fwd_bias_relu)
-add_subdirectory(07_conv2d_fwd_bias_relu_add)
 add_subdirectory(09_convnd_fwd)
-add_subdirectory(10_conv2d_bwd_data)
-add_subdirectory(11_conv2d_bwd_weight)
 add_subdirectory(12_reduce)
 add_subdirectory(13_pool2d_fwd)
 add_subdirectory(14_gemm_xdl_requant_relu_requant)
 add_subdirectory(15_grouped_gemm)
 add_subdirectory(16_gemm_reduce)
-add_subdirectory(17_convnd_bwd_data_xdl)
+add_subdirectory(17_convnd_bwd_data)
 add_subdirectory(18_batched_gemm_reduce)
 add_subdirectory(19_binary_elementwise)
-add_subdirectory(20_convnd_bwd_weight_xdl)
+add_subdirectory(20_convnd_bwd_weight)
 add_subdirectory(21_gemm_layernorm)
 add_subdirectory(22_cgemm)
 add_subdirectory(23_softmax)
-add_subdirectory(24_batched_gemm_c_permute)
-add_subdirectory(25_gemm_bias_c_permute)
+add_subdirectory(24_batched_gemm_e_permute)
+add_subdirectory(25_gemm_bias_e_permute)
 add_subdirectory(26_contraction)
+add_subdirectory(27_layernorm)
+add_subdirectory(28_grouped_gemm_bias)
+add_subdirectory(29_batched_gemm_multi_d)
+add_subdirectory(30_grouped_convnd_fwd_bias_relu)
@@ -146,7 +146,7 @@
 // workaround: verification failure, due to compiler regression, for conv bwd-data fp16 using some
 // tuning parameter
-#define CK_WORKAROUND_SWDEV_325164 1
+#define CK_WORKAROUND_SWDEV_325164 0
 namespace ck {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <vector>
#include <array>
#include <iterator>
#include "ck/tensor_description/tensor_descriptor.hpp"
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
template <typename T, std::size_t N>
std::ostream& operator<<(std::ostream& os, const std::array<T, N>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
template <typename... Ts>
std::ostream& operator<<(std::ostream& os, const ck::TensorDescriptor<Ts...>& desc)
{
constexpr ck::index_t nDim = ck::remove_cvref_t<decltype(desc)>::GetNumOfDimension();
os << "{";
ck::static_for<0, nDim - 1, 1>{}([&](auto i) { os << desc.GetLength(i) << ", "; });
os << desc.GetLength(ck::Number<nDim - 1>{});
os << "}";
return os;
}
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/stream_config.hpp"
-#include "ck/device_utility/hip_check_error.hpp"
+#include "ck/host_utility/hip_check_error.hpp"
 template <typename... Args, typename F>
 float launch_and_time_kernel(const StreamConfig& stream_config,
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
-#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
+#pragma once
 namespace ck {
 namespace tensor_operation {
@@ -14,7 +13,18 @@ enum struct ConvolutionBackwardDataSpecialization
     Filter1x1Stride1Pad0,
 };
+inline std::string
+getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization& s)
+{
+    switch(s)
+    {
+    case ConvolutionBackwardDataSpecialization::Default: return "Default";
+    case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0:
+        return "Filter1x1Stride1Pad0";
+    default: return "Unrecognized specialization!";
+    }
+}
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
@@ -15,6 +15,19 @@ enum struct ConvolutionBackwardWeightSpecialization
     OddC,
 };
+inline std::string
+getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization& s)
+{
+    switch(s)
+    {
+    case ConvolutionBackwardWeightSpecialization::Default: return "Default";
+    case ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0:
+        return "Filter1x1Stride1Pad0";
+    case ConvolutionBackwardWeightSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
+    case ConvolutionBackwardWeightSpecialization::OddC: return "OddC";
+    default: return "Unrecognized specialization!";
+    }
+}
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
-#define CONVOLUTION_FORWARD_SPECIALIZATION
+#pragma once
 #include <string>
@@ -33,4 +32,3 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
@@ -12,8 +12,8 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
-#include "ck/device_utility/device_prop.hpp"
-#include "ck/device_utility/kernel_launch.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
 namespace ck {
 namespace tensor_operation {
...
@@ -8,41 +8,43 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-struct BatchedGemmCPermuteDesc
+struct BatchedGemmEPermuteDesc
 {
     ck::index_t G0_, G1_, M_, N_;
     ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
 };
-template <typename AElementwiseOperation,
+template <typename ALayout,
+          typename BLayout,
+          typename DELayout,
+          typename ADataType,
+          typename BDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceBatchedGemmCPermute : public BaseOperator
+          typename CDEElementwiseOperation>
+struct DeviceBatchedGemmEPermute : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
-                        void* p_c,
+                        void* p_e,
                         index_t M,
                         index_t N,
                         index_t K,
                         index_t stride_A,
                         index_t stride_B,
-                        BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                        index_t batch_stride_A,
+                        index_t batch_stride_B,
+                        BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
+                        index_t BatchCount,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        ck::index_t BatchCount) = 0;
+                        CDEElementwiseOperation cde_element_op) = 0;
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
-    DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,11 +7,12 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
-#include "ck/device_utility/device_prop.hpp"
-#include "ck/device_utility/kernel_launch.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
 
 namespace ck {
 namespace tensor_operation {
@@ -24,12 +25,12 @@ namespace device {
  * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
  * strided batches, but we can easily extend to other layouts. The returned offset can be either \p
  * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
  * limitations.
  *
- * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in the id of a workgroup and
+ * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in the id of a workgroup and
  * returns the 2D index of the tile that it computes. \see
  * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
  *
  * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
  * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
  * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
@@ -37,40 +38,40 @@ namespace device {
  * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
  * pointer offset into \p ComputePtrOffsetOfStridedBatch.
  *
- * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
+ * \note \p Block2ETileMap allows customized mapping between a workgroup and the E-tile it computes.
  * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
  * realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
  *
  */
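To make the note above concrete, here is a simplified sketch (not from this commit) of the kind of functor the kernel consumes. The ComputePtrOffsetOfStridedBatch actually defined later in this file also stores the E grid descriptor, so that GetCPtrOffset() can translate the flat batch index g_idx into (g0, g1) coordinates of the permuted output; the version below assumes plain evenly strided batches for all three tensors.

// Editorial sketch, assuming evenly strided A, B and E batches; the real
// ComputePtrOffsetOfStridedBatch in this file derives the E offset from
// EGridDesc_G0_G1_M_N instead of a single stride.
struct ComputePtrOffsetOfStridedBatchSketch
{
    index_t batch_stride_A_;
    index_t batch_stride_B_;
    index_t batch_stride_E_;

    __host__ __device__ long_index_t GetAPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(g_idx) * batch_stride_A_;
    }

    __host__ __device__ long_index_t GetBPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(g_idx) * batch_stride_B_;
    }

    __host__ __device__ long_index_t GetCPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(g_idx) * batch_stride_E_;
    }
};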
 template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename AGridDesc_K0_M_K1,
-          typename BGridDesc_K0_N_K1,
-          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename ABDataType,
+          typename EDataType,
+          typename AGridDesc_AK0_M_AK1,
+          typename BGridDesc_BK0_N_BK1,
+          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
+          typename CDEElementwiseOperation,
           typename ComputePtrOffsetOfBatch,
-          typename Block2CTileMap,
+          typename Block2ETileMap,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
-                                          const FloatAB* __restrict__ p_b_grid,
-                                          FloatC* __restrict__ p_c_grid,
+        kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
+                                          const ABDataType* __restrict__ p_b_grid,
+                                          EDataType* __restrict__ p_e_grid,
                                           const index_t batch_count,
-                                          const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-                                          const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-                                          const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                              c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                          const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                                          const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                                          const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                                              e_grid_desc_mblock_mperblock_nblock_nperblock,
                                           const AElementwiseOperation a_element_op,
                                           const BElementwiseOperation b_element_op,
-                                          const CElementwiseOperation c_element_op,
+                                          const CDEElementwiseOperation cde_element_op,
                                           const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                                          const Block2CTileMap block_2_ctile_map)
+                                          const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
@@ -81,92 +82,102 @@ __global__ void
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
+    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
 
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(
-        p_a_grid + a_batch_offset,
-        p_b_grid + b_batch_offset,
-        ck::Tuple<>{},
-        p_c_grid + c_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        c_element_op,
-        a_grid_desc_k0_m_k1,
-        b_grid_desc_k0_n_k1,
-        ck::StaticallyIndexedArray<
-            typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-            0>{},
-        c_grid_desc_mblock_mperblock_nblock_nperblock,
-        block_2_ctile_map);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                  p_b_grid + b_batch_offset,
+                                                  ck::Tuple<>{},
+                                                  p_e_grid + e_batch_offset,
+                                                  p_shared,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  cde_element_op,
+                                                  a_grid_desc_ak0_m_ak1,
+                                                  b_grid_desc_bk0_n_bk1,
+                                                  ck::Tuple<>{},
+                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  block_2_etile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
-    ignore = p_c_grid;
+    ignore = p_e_grid;
     ignore = batch_count;
-    ignore = a_grid_desc_k0_m_k1;
-    ignore = b_grid_desc_k0_n_k1;
-    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
     ignore = b_element_op;
-    ignore = c_element_op;
+    ignore = cde_element_op;
     ignore = compute_ptr_offset_of_batch;
-    ignore = block_2_ctile_map;
+    ignore = block_2_etile_map;
 #endif
 }
 template <typename ALayout,
           typename BLayout,
+          typename ELayout,
           typename ADataType,
           typename BDataType,
-          typename CDataType,
           typename AccDataType,
+          typename CShuffleDataType,
+          typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
+          typename CDEElementwiseOperation,
           GemmSpecialization GemmSpec,
-          ck::index_t NumPrefetch,
-          ck::index_t BlockSize,
-          ck::index_t MPerBlock,
-          ck::index_t NPerBlock,
-          ck::index_t KPerBlock,
-          ck::index_t AK1,
-          ck::index_t BK1,
-          ck::index_t MPerXDL,
-          ck::index_t NPerXDL,
-          ck::index_t MXdlPerWave,
-          ck::index_t NXdlPerWave,
+          index_t NumPrefetch,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t AK1,
+          index_t BK1,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MXdlPerWave,
+          index_t NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
-          ck::index_t ABlockTransferSrcVectorDim,
-          ck::index_t ABlockTransferSrcScalarPerVector,
-          ck::index_t ABlockTransferDstScalarPerVector_K1,
-          bool ABlockLdsAddExtraM,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_K1,
+          index_t ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
-          ck::index_t BBlockTransferSrcVectorDim,
-          ck::index_t BBlockTransferSrcScalarPerVector,
-          ck::index_t BBlockTransferDstScalarPerVector_K1,
-          bool BBlockLdsAddExtraN,
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_K1,
+          index_t BBlockLdsExtraN,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
-                                                                       BElementwiseOperation,
-                                                                       CElementwiseOperation>
+struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
+                                                                       BLayout,
+                                                                       ELayout,
+                                                                       ADataType,
+                                                                       BDataType,
+                                                                       EDataType,
+                                                                       AElementwiseOperation,
+                                                                       BElementwiseOperation,
+                                                                       CDEElementwiseOperation>
 {
+    using DeviceOp = DeviceBatchedGemmEPermuteXdl;
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
 
-    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
+
+    static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -181,95 +192,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
             }
         }();
 
-        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
-
-        const auto MPad = M - MRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad both M and K
-            assert(K % AK1 == 0);
-
-            const auto AK0 = K / AK1;
-
-            const auto a_grid_desc_m_k =
-                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
-                                            make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                       make_right_pad_transform(KRaw, KPad)),
-                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            const auto a_grid_desc_ak0_m_ak1 =
-                transform_tensor_descriptor(a_grid_desc_m_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                       make_pass_through_transform(M)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad M, but not K
-            assert(KRaw % AK1 == 0);
-
-            const auto AK0 = KRaw / AK1;
-
-            const auto a_grid_desc_ak0_m_ak1 =
-                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
-                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                       make_right_pad_transform(MRaw, MPad)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::NKPadding)
-        {
-            // pad K, but not M
-            assert(K % AK1 == 0);
-
-            const auto AK0 = K / AK1;
-
-            const auto a_grid_desc_m_k = transform_tensor_descriptor(
-                a_grid_desc_mraw_kraw,
-                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            const auto a_grid_desc_ak0_m_ak1 =
-                transform_tensor_descriptor(a_grid_desc_m_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                       make_pass_through_transform(MRaw)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else
-        {
-            // not pad M or K
-            assert(KRaw % AK1 == 0);
-
-            const auto AK0 = KRaw / AK1;
-
-            const auto a_grid_desc_ak0_m_ak1 =
-                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
-                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                       make_pass_through_transform(MRaw)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
+        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
     }
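All of the removed branches above performed the same rounding; the new matrix_padder member now owns that arithmetic. As a worked example of what the padding does to the extents (not from this commit, values illustrative):

// Editorial note, not from this commit: the rounding formerly spelled out above,
// now performed inside MatrixPadder. Illustrative values:
//   MRaw = 1000, MPerBlock = 256 -> M = math::integer_divide_ceil(1000, 256) * 256 = 1024, MPad = 24
//   KRaw = 120,  KPerBlock = 32  -> K = math::integer_divide_ceil(120, 32) * 32   = 128,  KPad = 8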
-    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
     {
         const auto b_grid_desc_nraw_kraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -284,142 +210,16 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
             }
         }();
 
-        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
-        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
-
-        const auto NPad = N - NRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad both N and K
-            assert(K % BK1 == 0);
-
-            const auto BK0 = K / BK1;
-
-            const auto b_grid_desc_n_k =
-                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
-                                            make_tuple(make_right_pad_transform(NRaw, NPad),
-                                                       make_right_pad_transform(KRaw, KPad)),
-                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            const auto b_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b_grid_desc_n_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                       make_pass_through_transform(N)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad N, but not K
-            assert(KRaw % BK1 == 0);
-
-            const auto BK0 = KRaw / BK1;
-
-            const auto b_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
-                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                       make_right_pad_transform(NRaw, NPad)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::MKPadding)
-        {
-            // pad K, but not N
-            assert(K % BK1 == 0);
-
-            const auto BK0 = K / BK1;
-
-            const auto b_grid_desc_n_k = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            const auto b_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b_grid_desc_n_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                       make_pass_through_transform(NRaw)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else
-        {
-            // not pad N or K
-            assert(KRaw % BK1 == 0);
-
-            const auto BK0 = KRaw / BK1;
-
-            const auto b_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
-                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                       make_pass_through_transform(NRaw)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
+        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
     }
     static auto
-    MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
+    MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
     {
-        const auto c_grid_desc_mraw_nraw = [&]() {
-            return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
-                                                make_tuple(stride_M, stride_N));
-        }();
-
-        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
-
-        const auto MPad = M - MRaw;
-        const auto NPad = N - NRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad M and N
-            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
-                                               make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                          make_right_pad_transform(NRaw, NPad)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                          GemmSpec == GemmSpecialization::MKPadding)
-        {
-            // pad M, but not N
-            return transform_tensor_descriptor(
-                c_grid_desc_mraw_nraw,
-                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                          GemmSpec == GemmSpecialization::NKPadding)
-        {
-            // pad N, but not M
-            return transform_tensor_descriptor(
-                c_grid_desc_mraw_nraw,
-                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            // not pad M or N
-            return c_grid_desc_mraw_nraw;
-        }
+        const auto e_grid_desc_mraw_nraw =
+            make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));
+
+        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
     }
     static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
@@ -489,9 +289,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
         }
     }
 
-    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
+    using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
+    using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
+    using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
     using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
 
     struct ComputePtrOffsetOfStridedBatch
@@ -529,19 +329,20 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
         EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
     };
-    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
+    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
-        CDataType, // CShuffleDataType,
+        CShuffleDataType,
         ck::Tuple<>, // DsDataType,
-        CDataType, // EDataType,
+        EDataType, // EDataType,
         AElementwiseOperation,
         BElementwiseOperation,
-        CElementwiseOperation,
+        CDEElementwiseOperation,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
+        AGridDesc_M_K,
+        BGridDesc_N_K,
+        Tuple<>,
+        EGridDesc_M_N,
         NumPrefetch,
         BlockSize,
         MPerBlock,
@@ -560,7 +361,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
-       ABlockLdsAddExtraM,
+       ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
@@ -568,118 +369,134 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
-       BBlockLdsAddExtraN,
+       BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;
 
-    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
-    using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
+        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
+        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
+        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
+
+    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
     // Argument
     struct Argument : public BaseArgument
     {
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
-                 CDataType* p_c_grid,
+                 EDataType* p_e_grid,
                  index_t M,
                  index_t N,
                  index_t K,
                  index_t stride_A,
                  index_t stride_B,
-                 BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                 index_t batch_stride_A,
+                 index_t batch_stride_B,
+                 BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
+                 index_t BatchCount,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
-                 CElementwiseOperation c_element_op,
-                 index_t BatchCount)
+                 CDEElementwiseOperation cde_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
-              p_c_grid_{p_c_grid},
+              p_e_grid_{p_e_grid},
               BatchCount_(BatchCount),
-              a_grid_desc_k0_m_k1_{
-                  DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
-              b_grid_desc_k0_n_k1_{
-                  DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)},
-              c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
-                  batched_gemm_c_permute_desc.M_,
-                  batched_gemm_c_permute_desc.N_,
-                  batched_gemm_c_permute_desc.stride_M_,
-                  batched_gemm_c_permute_desc.stride_N_)},
-              e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N(
-                  batched_gemm_c_permute_desc.G0_,
-                  batched_gemm_c_permute_desc.G1_,
-                  batched_gemm_c_permute_desc.M_,
-                  batched_gemm_c_permute_desc.N_,
-                  batched_gemm_c_permute_desc.stride_G0_,
-                  batched_gemm_c_permute_desc.stride_G1_,
-                  batched_gemm_c_permute_desc.stride_M_,
-                  batched_gemm_c_permute_desc.stride_N_)},
-              c_grid_desc_mblock_mperblock_nblock_nperblock{},
-              compute_ptr_offset_of_batch_{
-                  type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
-                  type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
-                  e_grid_desc_g0_g1_m_n_},
-              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
+              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)},
+              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)},
+              e_grid_desc_m_n_{
+                  DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
+                                                    batched_gemm_e_permute_desc.N_,
+                                                    batched_gemm_e_permute_desc.stride_M_,
+                                                    batched_gemm_e_permute_desc.stride_N_)},
+              a_grid_desc_ak0_m_ak1_{
+                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+              b_grid_desc_bk0_n_bk1_{
+                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+              e_grid_desc_mblock_mperblock_nblock_nperblock{},
+              e_grid_desc_g0_g1_m_n_{
+                  DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
+                                                          batched_gemm_e_permute_desc.G1_,
+                                                          batched_gemm_e_permute_desc.M_,
+                                                          batched_gemm_e_permute_desc.N_,
+                                                          batched_gemm_e_permute_desc.stride_G0_,
+                                                          batched_gemm_e_permute_desc.stride_G1_,
+                                                          batched_gemm_e_permute_desc.stride_M_,
+                                                          batched_gemm_e_permute_desc.stride_N_)},
+              compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
+              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
-              c_element_op_{c_element_op}
+              cde_element_op_{cde_element_op}
         {
-            if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
-                                           b_grid_desc_k0_n_k1_,
-                                           c_grid_desc_m_n_,
-                                           block_2_ctile_map_))
+            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+                                           b_grid_desc_n_k_,
+                                           ck::Tuple<>{},
+                                           e_grid_desc_m_n_,
+                                           block_2_etile_map_))
             {
-                c_grid_desc_mblock_mperblock_nblock_nperblock =
+                e_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        c_grid_desc_m_n_);
+                        e_grid_desc_m_n_);
             }
         }
 
+        void Print() const
+        {
+            std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
+            std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
+            std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
+        }
+
         //  private:
+        // pointers
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        CDataType* p_c_grid_;
+        EDataType* p_e_grid_;
+
+        // batch count
         index_t BatchCount_;
-        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
-        CGridDesc_M_N c_grid_desc_m_n_;
+
+        // tensor descriptors for problem definition
+        AGridDesc_M_K a_grid_desc_m_k_;
+        BGridDesc_N_K b_grid_desc_n_k_;
+        EGridDesc_M_N e_grid_desc_m_n_;
+
+        // tensor descriptors for block/thread-wise copy
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
        EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
-        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        // for calculating batch offset
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
-        Block2CTileMap block_2_ctile_map_;
+
+        // block-to-E-tile map
+        Block2ETileMap block_2_etile_map_;
+
+        // element-wise ops
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
-        CElementwiseOperation c_element_op_;
+        CDEElementwiseOperation cde_element_op_;
     };
     // Invoker
     struct Invoker : public BaseInvoker
     {
-        using Argument = DeviceBatchedGemmCPermuteXdl::Argument;
+        using Argument = DeviceOp::Argument;
 
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            {
-                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-            }
-
-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                            arg.b_grid_desc_k0_n_k1_,
-                                            arg.c_grid_desc_m_n_,
-                                            arg.block_2_ctile_map_))
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+                                            arg.b_grid_desc_n_k_,
+                                            ck::Tuple<>{},
+                                            arg.e_grid_desc_m_n_,
+                                            arg.block_2_etile_map_))
             {
                 throw std::runtime_error(
                     "wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
@@ -687,26 +504,24 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
             }
 
             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;
 
             const auto K =
-                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
-
-            float ave_time = 0;
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
 
             auto launch_kernel = [&](auto has_main_k_block_loop_) {
-                const auto kernel = kernel_batched_gemm_c_permute_xdl<
+                const auto kernel = kernel_batched_gemm_e_permute_xdl<
                     GridwiseGemm,
                     ADataType, // TODO: distinguish A/B datatype
-                    CDataType,
-                    remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
+                    EDataType,
+                    remove_reference_t<DeviceOp::AGridDesc_AK0_M_AK1>,
+                    remove_reference_t<DeviceOp::BGridDesc_BK0_N_BK1>,
                     typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     AElementwiseOperation,
                     BElementwiseOperation,
-                    CElementwiseOperation,
+                    CDEElementwiseOperation,
                     ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<Block2CTileMap>,
+                    remove_reference_t<Block2ETileMap>,
                     has_main_k_block_loop_>;
 
                 return launch_and_time_kernel(stream_config,
@@ -716,28 +531,26 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
                                               0,
                                               arg.p_a_grid_,
                                               arg.p_b_grid_,
-                                              arg.p_c_grid_,
+                                              arg.p_e_grid_,
                                               arg.BatchCount_,
-                                              arg.a_grid_desc_k0_m_k1_,
-                                              arg.b_grid_desc_k0_n_k1_,
-                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                              arg.a_grid_desc_ak0_m_ak1_,
+                                              arg.b_grid_desc_bk0_n_bk1_,
+                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                               arg.a_element_op_,
                                               arg.b_element_op_,
-                                              arg.c_element_op_,
+                                              arg.cde_element_op_,
                                               arg.compute_ptr_offset_of_batch_,
-                                              arg.block_2_ctile_map_);
+                                              arg.block_2_etile_map_);
             };
 
             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
-                ave_time = launch_kernel(integral_constant<bool, true>{});
+                return launch_kernel(integral_constant<bool, true>{});
             }
             else
            {
-                ave_time = launch_kernel(integral_constant<bool, false>{});
+                return launch_kernel(integral_constant<bool, false>{});
            }
-
-            return ave_time;
        }
        // polymorphic
@@ -756,10 +569,11 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
 
     static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                           arg.b_grid_desc_k0_n_k1_,
-                                           arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+                                           arg.b_grid_desc_n_k_,
+                                           ck::Tuple<>{},
+                                           arg.e_grid_desc_m_n_,
+                                           arg.block_2_etile_map_);
     }
 
     // polymorphic
@@ -770,31 +584,35 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
-                             CDataType* p_c,
+                             EDataType* p_e,
                              index_t M,
                              index_t N,
                              index_t K,
                              index_t stride_A,
                              index_t stride_B,
-                             BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                             index_t batch_stride_A,
+                             index_t batch_stride_B,
+                             BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
+                             index_t BatchCount,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             index_t BatchCount)
+                             CDEElementwiseOperation cde_element_op)
     {
         return Argument{p_a,
                         p_b,
-                        p_c,
+                        p_e,
                         M,
                         N,
                         K,
                         stride_A,
                         stride_B,
-                        batched_gemm_c_permute_desc,
+                        batch_stride_A,
+                        batch_stride_B,
+                        batched_gemm_e_permute_desc,
+                        BatchCount,
                         a_element_op,
                         b_element_op,
-                        c_element_op,
-                        BatchCount};
+                        cde_element_op};
     }
 
     static auto MakeInvoker() { return Invoker{}; }
@@ -803,31 +621,35 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
-                        void* p_c,
+                        void* p_e,
                         index_t M,
                         index_t N,
                         index_t K,
                         index_t stride_A,
                        index_t stride_B,
-                        BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                        index_t batch_stride_A,
+                        index_t batch_stride_B,
+                        BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
+                        index_t BatchCount,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        index_t BatchCount) override
+                        CDEElementwiseOperation cde_element_op) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
-                                          static_cast<CDataType*>(p_c),
+                                          static_cast<EDataType*>(p_e),
                                           M,
                                           N,
                                           K,
                                           stride_A,
                                           stride_B,
-                                          batched_gemm_c_permute_desc,
+                                          batch_stride_A,
+                                          batch_stride_B,
+                                          batched_gemm_e_permute_desc,
+                                          BatchCount,
                                           a_element_op,
                                           b_element_op,
-                                          c_element_op,
-                                          BatchCount);
+                                          cde_element_op);
    }
 
    // polymorphic
@@ -842,7 +664,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedGemmCPermuteXdl"
+        str << "DeviceBatchedGemmEPermuteXdl"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
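For completeness, a hedged end-to-end sketch (not from this commit) of driving the renamed op, following the MakeArgument/MakeInvoker pattern CK device ops share. DeviceOpInstance stands for some fully parameterized DeviceBatchedGemmEPermuteXdl<...>; the element-wise ops, pointers, sizes, strides, and the StreamConfig timing flag are all assumptions for illustration.

// Editorial sketch, not from this commit; all names and values are hypothetical.
auto op       = DeviceOpInstance{};
auto argument = op.MakeArgument(p_a, p_b, p_e,
                                M, N, K,
                                stride_A, stride_B,
                                batch_stride_A, batch_stride_B,
                                e_desc,  // BatchedGemmEPermuteDesc, as sketched earlier
                                G0 * G1, // BatchCount
                                PassThrough{}, PassThrough{}, PassThrough{});

if(!DeviceOpInstance::IsSupportedArgument(argument))
{
    throw std::runtime_error("wrong! this problem is not supported by this instance");
}

auto invoker   = op.MakeInvoker();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); // true: time the kernel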