Commit dd6a8de4 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

Merge branch 'develop' into jd/dev_pkg

parents 0aa899aa abf4bdb9
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
# Instructions for ```example_grouped_gemm_xdl```
## Run ```example_grouped_gemm_xdl```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
./bin/example_grouped_gemm_xdl_fp16 0 1 5
```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
```
gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1}
gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1}
gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1}
gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1}
group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128}
group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256}
group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384}
group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512}
launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 5 times...
Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2>
```
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_grouped_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto GemmMNPadding =
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
exit(0);
}
int group_count = 4;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
gemm_shapes.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N});
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
std::vector<Tensor<ADataType>> a_tensors;
;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{})));
c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
}
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, nrepeat);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(do_verification)
{
for(int i = 0; i < gemm_shapes.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
}
}
return 0;
}
add_example_executable(example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using DDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum;
using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = 1;
int init_method = 1;
int nrepeat = 5;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
// do GEMM
auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
// warm up
invoker.Run(argument);
// timing
float total_time = 0;
for(int i = 0; i < nrepeat; ++i)
{
// init DO, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();
KernelTimer timer;
timer.Start();
invoker.Run(argument);
timer.End();
total_time += timer.GetElapsedTime();
}
float ave_time = total_time / nrepeat;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetReduceZeroValue();
float d1_acc = d1_reduce_op.GetReduceZeroValue();
for(int n = 0; n < N; ++n)
{
d0_reduce_op.Reduce(d0_acc, c_m_n_host_result(m, n));
d1_reduce_op.Reduce(d1_acc, c_m_n_host_result(m, n));
}
d0_m_host_result(m) = d0_acc;
d1_m_host_result(m) = d1_acc;
}
check_error(c_m_n_host_result, c_m_n_device_result);
check_error(d0_m_host_result, d0_m_device_result);
check_error(d1_m_host_result, d1_m_device_result);
}
return 0;
}
add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
# Instructions for ```example_convnd_bwd_data_xdl```
## Run ```example_example_convnd_bwd_data_xdl```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: num_dim_spatial(1|2|3)
#arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, S[z,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPy,] [RightPy,] RightPx
./bin/example_convnd_bwd_data_xdl 0 1 5
```
Result
```
in_n_c_hi_wi: dim 4, lengths {128, 128, 71, 71}, strides {645248, 1, 9088, 128}
wei_k_c_y_x: dim 4, lengths {256, 128, 3, 3}, strides {1152, 1, 384, 128}
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{128, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{32, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Perf: 1.40031 ms, 69.8734 TFlops, 179.037 GB/s
```
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "conv_fwd_util.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "element_wise_operation.hpp"
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
#include "reference_conv_bwd_data.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdDefault =
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
using DeviceConvBwdDataBasePtr =
ck::tensor_operation::device::DeviceConvBwdDataPtr<InElementOp, WeiElementOp, OutElementOp>;
template <ck::index_t NumDimSpatial>
using DeviceConvNDBwdDataInstance = ck::tensor_operation::device::
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdDefault, // ConvolutionBackwardDataSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
2, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
7,
1>; // GemmCThreadTransferDstScalarPerVector
template <ck::index_t NumDimSpatial>
using ReferenceConvBwdDataInstance =
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NumDimSpatial>;
void print_use_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
<< "arg3: run kernel # of times (>1)\n"
<< "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
}
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck::utils::conv::ConvParams params;
int arg_idx = 5;
params.num_dim_spatial = num_dim_spatial;
params.N = std::stoi(argv[arg_idx++]);
params.K = std::stoi(argv[arg_idx++]);
params.C = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
params.input_spatial_lengths.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_strides.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_dilations.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
}
params.input_left_pads.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
}
params.input_right_pads.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
}
return params;
}
DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 3: {
return std::make_unique<DeviceConvNDBwdDataInstance<3>>();
}
case 2: {
return std::make_unique<DeviceConvNDBwdDataInstance<2>>();
}
case 1: {
return std::make_unique<DeviceConvNDBwdDataInstance<1>>();
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
int main(int argc, char* argv[])
{
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
int num_dim_spatial = 2;
ck::utils::conv::ConvParams params;
params.C = 128;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
else if(argc > 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]);
// check args number
int conv_args = 3 + num_dim_spatial * 6;
int cmdline_nargs = conv_args + 5;
if(cmdline_nargs != argc)
{
print_use_msg();
exit(1);
}
params = parse_conv_params(num_dim_spatial, argv);
}
else if(argc != 1)
{
print_use_msg();
exit(1);
}
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
static_cast<std::size_t>(params.C)};
input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths),
std::end(params.input_spatial_lengths));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
static_cast<std::size_t>(params.C)};
filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths),
std::end(params.filter_spatial_lengths));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
static_cast<std::size_t>(params.K)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
Tensor<InDataType> in_n_c_hi_wi_host_result(
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<InDataType> in_n_c_hi_wi_device_result(
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x(
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<OutDataType> out_n_k_ho_wo(
ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.2, 0.2});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.2, 0.2});
break;
default:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) *
in_n_c_hi_wi_device_result.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// reset input to zero
in_device_buf.SetZero();
// do GEMM
auto conv = get_conv_instance(num_dim_spatial);
auto invoker = conv->MakeInvokerPointer();
auto argument =
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N,
params.K,
params.C,
params.input_spatial_lengths,
params.filter_spatial_lengths,
output_spatial_lengths,
params.conv_filter_strides,
params.conv_filter_dilations,
params.input_left_pads,
params.input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(!conv->IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
float ave_time = invoker->Run(argument.get(), nrepeat);
std::size_t flop = ck::utils::conv::get_flops(
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
params.N,
params.C,
params.K,
params.input_spatial_lengths,
params.filter_spatial_lengths,
output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
if(do_verification)
{
auto verify_f = [&](const auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
wei_k_c_y_x,
out_n_k_ho_wo,
params.conv_filter_strides,
params.conv_filter_dilations,
params.input_left_pads,
params.input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
};
switch(num_dim_spatial)
{
case 3: {
auto ref_conv = ReferenceConvBwdDataInstance<3>();
verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvBwdDataInstance<2>();
verify_f(ref_conv);
break;
}
case 1: {
auto ref_conv = ReferenceConvBwdDataInstance<1>();
verify_f(ref_conv);
break;
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
}
add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_batched_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using DDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum;
using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = 1;
int init_method = 1;
int nrepeat = 5;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
ck::index_t BatchCount = 4;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
BatchCount = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n");
exit(0);
}
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;
std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl;
std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
// do GEMM
auto batched_gemm = DeviceBatchedGemmReduceInstance{};
auto invoker = batched_gemm.MakeInvoker();
auto argument =
batched_gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
BatchCount);
if(!batched_gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
// warm up
invoker.Run(argument);
// timing
float total_time = 0;
for(int i = 0; i < nrepeat; ++i)
{
// init DO, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();
KernelTimer timer;
timer.Start();
invoker.Run(argument);
timer.End();
total_time += timer.GetElapsedTime();
}
float ave_time = total_time / nrepeat;
std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
sizeof(BDataType) * BatchCount * K * N +
sizeof(CDataType) * BatchCount * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< batched_gemm.GetTypeString() << std::endl;
if(do_verification)
{
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
for(int batch = 0; batch < BatchCount; ++batch)
{
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetReduceZeroValue();
float d1_acc = d1_reduce_op.GetReduceZeroValue();
for(int n = 0; n < N; ++n)
{
d0_reduce_op.Reduce(d0_acc, c_g_m_n_host_result(batch, m, n));
d1_reduce_op.Reduce(d1_acc, c_g_m_n_host_result(batch, m, n));
}
d0_g_m_host_result(batch, m) = d0_acc;
d1_g_m_host_result(batch, m) = d1_acc;
}
}
check_error(c_g_m_n_host_result, c_g_m_n_device_result);
check_error(d0_g_m_host_result, d0_g_m_device_result);
check_error(d1_g_m_host_result, d1_g_m_device_result);
}
return 0;
}
...@@ -13,6 +13,7 @@ include_directories(BEFORE ...@@ -13,6 +13,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
${PROJECT_SOURCE_DIR}/library/include/ck/library/utility
${PROJECT_SOURCE_DIR}/external/include/half ${PROJECT_SOURCE_DIR}/external/include/half
) )
...@@ -30,12 +31,15 @@ add_subdirectory(01_gemm) ...@@ -30,12 +31,15 @@ add_subdirectory(01_gemm)
add_subdirectory(02_gemm_alpha_beta) add_subdirectory(02_gemm_alpha_beta)
add_subdirectory(03_gemm_bias_relu) add_subdirectory(03_gemm_bias_relu)
add_subdirectory(04_gemm_bias_relu_add) add_subdirectory(04_gemm_bias_relu_add)
add_subdirectory(05_conv2d_fwd)
add_subdirectory(06_conv2d_fwd_bias_relu) add_subdirectory(06_conv2d_fwd_bias_relu)
add_subdirectory(07_conv2d_fwd_bias_relu_add) add_subdirectory(07_conv2d_fwd_bias_relu_add)
add_subdirectory(08_conv3d_fwd)
add_subdirectory(09_convnd_fwd) add_subdirectory(09_convnd_fwd)
add_subdirectory(10_conv2d_bwd_data) add_subdirectory(10_conv2d_bwd_data)
add_subdirectory(11_conv2d_bwd_wgt) add_subdirectory(11_conv2d_bwd_weight)
add_subdirectory(12_reduce) add_subdirectory(12_reduce)
add_subdirectory(13_pool2d_fwd) add_subdirectory(13_pool2d_fwd)
add_subdirectory(14_gemm_xdl_requant_relu_requant)
add_subdirectory(17_convnd_bwd_data_xdl)
add_subdirectory(15_grouped_gemm)
add_subdirectory(16_gemm_reduce)
add_subdirectory(18_batched_gemm_reduce)
...@@ -6,15 +6,9 @@ ...@@ -6,15 +6,9 @@
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
#endif #endif
// "Constant" address space for kernel parameter // constant address space for kernel parameter
#define CONSTANT __attribute__((address_space(4))) // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
// GPU target
// should enable one and only one GPU target
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
#error Need to define (only) one GPU target
#endif
// launch bounds // launch bounds
#define CK_USE_LAUNCH_BOUNDS 1 #define CK_USE_LAUNCH_BOUNDS 1
...@@ -24,149 +18,134 @@ ...@@ -24,149 +18,134 @@
#define CK_MIN_BLOCK_PER_CU 2 #define CK_MIN_BLOCK_PER_CU 2
#endif #endif
// GPU-specific parameters // check GPU target
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ #ifdef __HIP_DEVICE_COMPILE__
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) #if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
// buffer resourse defined(__gfx90a__) || defined(__gfx1030__))
#error Not supported target
#endif
#endif
// buffer resourse, wave size
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_GPU_WAVE_SIZE -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
// wave size
#define CK_GPU_WAVE_SIZE 64 #define CK_GPU_WAVE_SIZE 64
#elif defined(CK_AMD_GPU_GFX1030) #elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_GPU_WAVE_SIZE 32 #define CK_GPU_WAVE_SIZE 32
#endif #endif
// FMA instruction // FMA instruction
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32 #define CK_USE_AMD_V_MAC_F32
#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \ #elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(CK_AMD_GPU_GFX1030) defined(__gfx1030__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8 #define CK_USE_AMD_V_DOT4_I32_I8
#endif #endif
// multi index // MFMA instruction
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_MFMA
// AMD inline asm #elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
#ifndef CK_USE_AMD_INLINE_ASM #define CK_USE_AMD_MFMA
#define CK_USE_AMD_INLINE_ASM 1
#endif #endif
// AMD inner product (DLOP) #if defined(__gfx90a__)
#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM #define CK_USE_AMD_MFMA_BF16_1K_OP
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
#endif #endif
// AMD buffer_load // buffer load
#ifndef CK_USE_AMD_BUFFER_LOAD
#define CK_USE_AMD_BUFFER_LOAD 1 #define CK_USE_AMD_BUFFER_LOAD 1
#endif
// AMD buffer_store // buffer store
#ifndef CK_USE_AMD_BUFFER_STORE
#define CK_USE_AMD_BUFFER_STORE 1 #define CK_USE_AMD_BUFFER_STORE 1
#endif
// AMD buffer_atomic_add // buffer atomic add: integer
#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD #define CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 1
#endif
// AMD XDLOPS // buffer atomic add: floating point
#ifndef CK_USE_AMD_XDLOPS #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_XDLOPS 0 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif #endif
// inline asm
#define CK_USE_AMD_INLINE_ASM 1
// inner product (DLOP)
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif
// experimental implementation for buffer load/store/atomic // experimental feature: multi index implemented as array
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #define CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK // experimental feature: static tensor descriptor
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK // experimental feature: buffer load/store/atomic-add OOB trick
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif
// experimental implementation for in-regsiter sub-dword transpose // experimental feature: in-regsiter sub-dword transpose
#ifndef CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1 #define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
#endif
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
// merge transformation use magic number division // experimental feature: merge transformation use magic number division
#ifndef CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
#endif
// use __builtin_memcpy instead of pointer cast to access a vector from pointer of scalar // experimental feature: use __builtin_memcpy instead of pointer cast to access a vector from
#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS // pointer of scalar
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
#endif
// use __builtin_memcpy instead of union to do bit_cast // experimental feature: use __builtin_memcpy instead of union to do bit_cast
#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
#endif
// hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
// thread-invariant, otherwise it's a bug // thread-invariant, otherwise it's a bug
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
#endif
// workaround for compiler crash when compiling recursive lambda // workaround: compiler crash when compiling recursive lambda
#ifndef CK_WORKAROUND_SWDEV_275126
#define CK_WORKAROUND_SWDEV_275126 1 #define CK_WORKAROUND_SWDEV_275126 1
#endif
// workaround for compiler crash when using buffer load/store for i8 // workaround: compiler crash when using buffer load/store for i8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
#endif
// workaround for compiler gnerating inefficient ds_write instructions // workaround: compiler gnerating inefficient ds_write instructions
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
// workaround for register spill due to compiler issue, when casting type between fp32 and fp16
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 1
#endif
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE // workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1
#endif
// workaround for verifaction failure, due to compiler regression, for conv bwd-data fp16 using some
// tuning parameter // tuning parameter
#ifndef CK_WORKAROUND_SWDEV_325164
#define CK_WORKAROUND_SWDEV_325164 1 #define CK_WORKAROUND_SWDEV_325164 1
#endif
// workaround for verification failure ConvNd forward
// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135
#define CK_WORKAROUND_GITHUB_135 1
namespace ck { namespace ck {
enum InMemoryDataOperationEnum_t enum struct InMemoryDataOperationEnum
{ {
Set, Set,
AtomicAdd, AtomicAdd,
Add Add
}; };
enum ActivTypeEnum_t // TODO: no longer needed, remove this
enum struct ActivTypeEnum
{ {
None, None,
LeakyRelu, LeakyRelu,
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
namespace ck { namespace ck {
// Number of GEMMs = YTilda * XTilda // Number of GEMMs = YTilde * XTilde
// GemmM = C // GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice // GemmN = N * HTildeSlice * WTildeSlice
// GemmK = K * YDotSlice * XDotSlice // GemmK = K * YDotSlice * XDotSlice
template <typename... Wei, template <typename... Wei,
typename... In, typename... In,
...@@ -18,8 +18,8 @@ template <typename... Wei, ...@@ -18,8 +18,8 @@ template <typename... Wei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads, typename InRightPads,
index_t IYTildaValue, index_t IYTildeValue,
index_t IXTildaValue, index_t IXTildeValue,
index_t GemmK1Value> index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...@@ -30,8 +30,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -30,8 +30,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
Number<IYTildaValue>, Number<IYTildeValue>,
Number<IXTildaValue>, Number<IXTildeValue>,
Number<GemmK1Value>) Number<GemmK1Value>)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -40,8 +40,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -40,8 +40,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{}; constexpr auto GemmK1 = Number<GemmK1Value>{};
constexpr auto IYTilda = Number<IYTildaValue>{}; constexpr auto IYTilde = Number<IYTildeValue>{};
constexpr auto IXTilda = Number<IXTildaValue>{}; constexpr auto IXTilde = Number<IXTildeValue>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
...@@ -71,55 +71,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -71,55 +71,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH; const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW; const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda); const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilda); const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor // only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildaSliceEnd = const auto IHTildeSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd = const auto IWTildeSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM // GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde);
const auto K1 = GemmK1; const auto K1 = GemmK1;
const auto K0 = K / K1; const auto K0 = K / K1;
// weight tensor // weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc, wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda), make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda), make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda), make_freeze_transform(IYTilde),
make_freeze_transform(IXTilda), make_freeze_transform(IXTilde),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
...@@ -163,25 +163,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -163,25 +163,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda), make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda), make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc, out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
...@@ -198,17 +198,17 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -198,17 +198,17 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
#if 1 #if 1
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else #else
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -224,24 +224,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -224,24 +224,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda), make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda), make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda), make_freeze_transform(IYTilde),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(IXTilda), make_freeze_transform(IXTilde),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
...@@ -257,9 +257,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -257,9 +257,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence<3>{})); Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc, in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_pass_through_transform(C), make_tuple(make_pass_through_transform(C),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
......
...@@ -10,8 +10,8 @@ namespace ck { ...@@ -10,8 +10,8 @@ namespace ck {
// A: out // A: out
// B: wei // B: wei
// C: in // C: in
// Number of GEMMs = YTilda * XTilda // Number of GEMMs = YTilde * XTilde
// GemmM = N * HTildaSlice * WTildaSlice // GemmM = N * HTildeSlice * WTildeSlice
// GemmN = C // GemmN = C
// GemmK = K * YDotSlice * XDotSlice // GemmK = K * YDotSlice * XDotSlice
template <typename... Wei, template <typename... Wei,
...@@ -21,8 +21,8 @@ template <typename... Wei, ...@@ -21,8 +21,8 @@ template <typename... Wei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads, typename InRightPads,
typename IYTilda, typename IYTilde,
typename IXTilda, typename IXTilde,
index_t GemmK1Value> index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
IYTilda i_ytilda, IYTilde i_ytilde,
IXTilda i_xtilda, IXTilde i_xtilde,
Number<GemmK1Value>) Number<GemmK1Value>)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -72,32 +72,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -72,32 +72,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH; const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW; const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda); const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilda); const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor // only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildaSliceEnd = const auto IHTildeSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd = const auto IWTildeSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM // GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto K1 = GemmK1; const auto K1 = GemmK1;
const auto K0 = K / K1; const auto K0 = K / K1;
...@@ -113,25 +113,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -113,25 +113,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda), make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda), make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc, out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
...@@ -148,41 +148,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -148,41 +148,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#if 1 #if 1
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else #else
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif #endif
// B: weight tensor // B: weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc, wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda), make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda), make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilda), make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilda), make_freeze_transform(i_xtilde),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
...@@ -225,24 +225,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -225,24 +225,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda), make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda), make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilda), make_freeze_transform(i_ytilde),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilda), make_freeze_transform(i_xtilde),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
...@@ -258,8 +258,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -258,8 +258,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence<3>{})); Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc, in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace ck { namespace ck {
// StaticTensor for Scalar // StaticTensor for Scalar
template <AddressSpaceEnum_t AddressSpace, template <AddressSpaceEnum AddressSpace,
typename T, typename T,
typename TensorDesc, typename TensorDesc,
bool InvalidElementUseNumericalZeroValue, bool InvalidElementUseNumericalZeroValue,
...@@ -80,7 +80,7 @@ struct StaticTensor ...@@ -80,7 +80,7 @@ struct StaticTensor
}; };
// StaticTensor for vector // StaticTensor for vector
template <AddressSpaceEnum_t AddressSpace, template <AddressSpaceEnum AddressSpace,
typename S, typename S,
index_t ScalarPerVector, index_t ScalarPerVector,
typename TensorDesc, typename TensorDesc,
...@@ -245,7 +245,7 @@ struct StaticTensorTupleOfVectorBuffer ...@@ -245,7 +245,7 @@ struct StaticTensorTupleOfVectorBuffer
S ignored_element_scalar_; S ignored_element_scalar_;
}; };
template <AddressSpaceEnum_t AddressSpace, template <AddressSpaceEnum AddressSpace,
typename T, typename T,
typename TensorDesc, typename TensorDesc,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
...@@ -255,7 +255,7 @@ __host__ __device__ constexpr auto make_static_tensor(TensorDesc) ...@@ -255,7 +255,7 @@ __host__ __device__ constexpr auto make_static_tensor(TensorDesc)
} }
template < template <
AddressSpaceEnum_t AddressSpace, AddressSpaceEnum AddressSpace,
typename T, typename T,
typename TensorDesc, typename TensorDesc,
typename X, typename X,
......
...@@ -207,9 +207,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -207,9 +207,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
"wrong"); "wrong");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_k_m0_m1_thread_desc_.GetElementSpaceSize()); a_k_m0_m1_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_k_n0_n1_thread_desc_.GetElementSpaceSize()); b_k_n0_n1_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm = constexpr auto threadwise_gemm =
......
...@@ -220,9 +220,9 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B ...@@ -220,9 +220,9 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
"wrong"); "wrong");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction = constexpr auto threadwise_contraction =
......
...@@ -119,7 +119,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -119,7 +119,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{}; constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
// thread A buffer for GEMM // thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true> StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf; a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA, constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
......
...@@ -42,7 +42,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -42,7 +42,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferTupleOfVector<AddressSpaceEnum_t::Vgpr, StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
MRepeat * NRepeat, MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(), xdlops_gemm.GetRegSizePerXdlops(),
...@@ -250,9 +250,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -250,9 +250,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
......
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename SrcElementwiseOperation, typename SrcElementwiseOperation,
typename DstElementwiseOperation, typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment