Commit 3aa0214f authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/composable_kernel into barkocot/conv-elementwise
parents ae950178 1cc36ba5
......@@ -14,7 +14,7 @@ None
### Additions
- Added an image-to-column kernel (#867)
- Added a column-to-image kernel (#930)
- Support for 3D grouped convolution forward on RDNA 3 GPUs (#935)
- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
- Grouped convolution support for small K and C (#822, #879, #897)
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769, #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757, #799)
......
......@@ -49,6 +49,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
vim \
nano \
zlib1g-dev \
zip \
openssh-server \
clang-format-12 \
kmod && \
......
......@@ -526,6 +526,26 @@ def Build_CK(Map conf=[:]){
stash "ckprofiler_0.2.0_amd64.deb"
}
}
if (params.hipTensor_test && navi_node == 0 ){
//build and test hipTensor
sh """#!/bin/bash
rm -rf "${params.hipTensor_branch}".zip
rm -rf hipTensor-"${params.hipTensor_branch}"
wget https://github.com/ROCmSoftwarePlatform/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip
unzip -o "${params.hipTensor_branch}".zip
"""
dir("hipTensor-${params.hipTensor_branch}"){
sh """#!/bin/bash
mkdir -p build
ls -ltr
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="/opt/rocm;${env.WORKSPACE}/install"
cmake --build build -- -j
"""
}
dir("hipTensor-${params.hipTensor_branch}/build"){
sh 'ctest'
}
}
}
}
}
......@@ -654,6 +674,15 @@ pipeline {
name: "DL_KERNELS",
defaultValue: false,
description: "Select whether to build DL kernels (default: OFF)")
booleanParam(
name: "hipTensor_test",
defaultValue: true,
description: "Use the CK build to verify hipTensor build and tests (default: ON)")
string(
name: 'hipTensor_branch',
defaultValue: 'mainline',
description: 'Specify which branch of hipTensor to use (default: mainline)')
}
environment{
dbuser = "${dbuser}"
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
if(result EQUAL 0)
......@@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endif()
set(target 1)
endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
endif()
set(target 1)
endif()
endforeach()
......
......@@ -46,25 +46,21 @@ struct CommonLayoutSetting
using OutputLayout = OutputLay;
};
template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector;
namespace ctl = ck::tensor_layout::convolution;
template <>
struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting<ctl::GNWC, ctl::GKXC, ctl::GNWK>
{
};
template <>
struct CommonLayoutSettingSelector<2> final
: CommonLayoutSetting<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>
{
};
template <>
struct CommonLayoutSettingSelector<3> final
: CommonLayoutSetting<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>
template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector
: CommonLayoutSetting<ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>>
{
};
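The rewritten selector above replaces three explicit specializations with a single template that indexes a layout list by NDimSpatial - 1. A minimal standalone sketch of the same compile-time selection, using std::tuple_element_t and hypothetical layout tags in place of the CK types:

#include <tuple>
#include <type_traits>

// Placeholder tags standing in for ck::tensor_layout::convolution::*.
struct GNWC {};
struct GNHWC {};
struct GNDHWC {};

// Pick the input layout for a given spatial dimensionality at compile time.
template <int NDimSpatial>
using InputLayout =
    std::tuple_element_t<NDimSpatial - 1, std::tuple<GNWC, GNHWC, GNDHWC>>;

static_assert(std::is_same_v<InputLayout<1>, GNWC>, "1D -> GNWC");
static_assert(std::is_same_v<InputLayout<2>, GNHWC>, "2D -> GNHWC");
static_assert(std::is_same_v<InputLayout<3>, GNDHWC>, "3D -> GNDHWC");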
......@@ -84,10 +80,10 @@ struct ExecutionConfig final
bool time_kernel = false;
};
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \
}
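// For reference, the fields of DefaultConvParam above map onto the
// ck::utils::conv::ConvParam constructor in this assumed order:
//   {num_dim_spatial, G, N, K, C,
//    filter_spatial_lengths, input_spatial_lengths,
//    conv_strides, conv_dilations, left_pads, right_pads}
// so the new 3D default is a 4-group, batch-1, K=128, C=256 convolution
// with a 3x3x3 filter over a 14x14x14 input.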
inline void print_help_msg()
......
......@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
using InDataType = F16;
using WeiDataType = F16;
using OutDataType = F16;
using AccDataType = F32;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
NDimSpatial,
ck::tensor_layout::convolution::GNDHWC,
ck::tensor_layout::convolution::GKZYXC,
ck::tensor_layout::convolution::GNDHWK,
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
4, // CShuffleMRepeatPerShuffle
2, // CShuffleNRepeatPerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1>; // CShuffleBlockTransferScalarPerVector_NPerBlock
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
// The WMMA device op above is declared only for NDimSpatial == 3
// (see its enable_if constraint), so only the 3-D case is dispatched.
switch(conv_param.num_dim_spatial_)
{
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param)
{
// Dl op doesn't support split_k > 1
// Dl and WMMA ops don't support split_k > 1
constexpr ck::index_t split_k = 1;
const auto in_g_n_c_wis_desc =
......@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return true;
}
bool run_grouped_conv_bwd_weight_example(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return false;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param);
}
return false;
}
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
endif()
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using A0DataType = F16;
using A1DataType = F32;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
ck::half_t& e, const float& c, const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
};
float alpha_;
float beta_;
};
struct Multiply
{
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const
{
a = ck::type_convert<ck::half_t>(ck::type_convert<float>(a0) * a1);
}
};
using AElementOp = Multiply;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
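// Taken together, the element-wise ops give this example the overall
// computation (sketch):
//   A(m, k) = Multiply(A0(m, k), A1(m, k))   // A1 broadcast, see strides below
//   C(m, n) = sum_k A(m, k) * B(n, k)        // the contraction itself
//   E(m, n) = AlphaBetaAdd(C(m, n), D(m, n)) = alpha * C + beta * D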
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle<
NumDimM,
NumDimN,
NumDimK,
ck::Tuple<A0DataType, A1DataType>,
ck::Tuple<BDataType>,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CDEBlockTransferScalarPerVector_NPerBlock
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
float alpha = 1.0f;
float beta = 1.0f;
// A0[M0, M1, K0, K1]
std::vector<ck::index_t> a0_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1};
// A1[M1, K1] -> A1[M0, M1, K0, K1]
std::vector<ck::index_t> a1_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 0, 1};
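// A1 is logically an [M1, K1] tensor broadcast to [M0, M1, K0, K1]: the
// zero strides map every m0 and k0 index onto the same element, i.e.
//   offset(m0, m1, k0, k1) = m0 * 0 + m1 * 64 + k0 * 0 + k1 * 1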
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<A0DataType> a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides);
Tensor<A1DataType> a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides);
Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<DDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl;
std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_ms_ks.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
a1_ms_ks.GenerateTensorValue(GeneratorTensor_2<A1DataType>{-5, 5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a0_ms_ks.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
a1_ms_ks.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break;
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_ms_ks.mData.data());
a1_device_buf.ToDevice(a1_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
d_device_buf.ToDevice(d_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
std::array<std::vector<ck::index_t>, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths},
std::array<std::vector<ck::index_t>, 2>{a0_ms_ks_strides, a1_ms_ks_strides},
std::array<std::vector<ck::index_t>, 1>{b_ns_ks_lengths},
std::array<std::vector<ck::index_t>, 1>{b_ns_ks_strides},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_contraction with the specified compilation parameters does "
"not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
if(time_kernel)
{
ck::index_t M =
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl;
}
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<A0DataType> a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides);
for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1)
{
for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0)
{
for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1)
{
a_element_op(a_ms_ks(m0, m1, k0, k1),
a0_ms_ks(m0, m1, k0, k1),
a1_ms_ks(m0, m1, k0, k1));
}
}
}
}
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
A0DataType,
BDataType,
CShuffleDataType,
AccDataType,
PassThrough,
BElementOp>;
auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_op.MakeInvoker();
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
auto ref_argument =
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A0[M0, M1, ... K0, K1, ...], ...
// input : B0[N0, N1, ... K0, K1, ...], ...
// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ...
// output : E[M0, M1, ... N0, N1, ...]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceContractionMultipleABD : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename AsPointer,
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AsGridDesc_AK0_M_AK1,
typename BsGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_multiple_abd_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
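// The guard below compiles the kernel body only for the XDL-capable
// targets (gfx908/90a/940/941/942) and for the host pass; for any other
// device target the parameters are consumed via `ignore` so the
// translation unit still compiles.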
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_as_grid;
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = as_grid_desc_ak0_m_ak1;
ignore = bs_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceContractionMultipleABD_Xdl_CShuffle
: public DeviceContractionMultipleABD<NumDimM,
NumDimN,
NumDimK,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle;
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ComputeDataType = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
AsDataType,
BsDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
static constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_,
const std::vector<index_t>& a_ms_ks_strides_)
{
assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK &&
a_ms_ks_strides_.size() == NumDimM + NumDimK);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number<NumDimM + NumDimK>{});
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number<NumDimM + NumDimK>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
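// Example (shapes from the fp16 contraction example above): lengths
// {30, 128, 32, 64} with NumDimM = NumDimK = 2 merge into
// MRaw = 30 * 128 = 3840 and KRaw = 32 * 64 = 2048; matrix_padder then
// pads those up to multiples of MPerBlock/KPerBlock as dictated by GemmSpec.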
__host__ __device__ static auto
MakeAsGridDescriptor_M_K(const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides)
{
return generate_tuple(
[&](auto i) {
return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]);
},
Number<NumATensor>{});
}
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_,
const std::vector<index_t>& b_ns_ks_strides_)
{
assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK &&
b_ns_ks_strides_.size() == NumDimN + NumDimK);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number<NumDimN + NumDimK>{});
const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number<NumDimN + NumDimK>{});
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
__host__ __device__ static auto
MakeBsGridDescriptor_N_K(const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides)
{
return generate_tuple(
[&](auto i) {
return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]);
},
Number<NumBTensor>{});
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_,
const std::vector<index_t>& e_ms_ns_strides_)
{
assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN &&
e_ms_ns_strides_.size() == NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number<NumDimM + NumDimN>{});
const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number<NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const auto e_grid_desc_ms_ns =
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
e_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto
MakeDsGridDescriptor_M_N(const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
},
Number<NumDTensor>{});
}
// desc for problem definition
using AsGridDesc_M_K = remove_cvref_t<decltype(MakeAsGridDescriptor_M_K({}, {}))>;
using BsGridDesc_N_K = remove_cvref_t<decltype(MakeBsGridDescriptor_N_K({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N({}, {}))>;
// desc for blockwise copy
using AsGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(AsGridDesc_M_K{}))>;
using BsGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(BsGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(std::array<const void*, NumATensor> p_as_grid,
std::array<const void*, NumBTensor> p_bs_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_as_grid_{},
p_bs_grid_{},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
as_grid_desc_m_k_{},
bs_grid_desc_n_k_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)},
as_grid_desc_ak0_m_ak1_{},
bs_grid_desc_bk0_n_bk1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
// using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);
// A desc
as_grid_desc_m_k_(i) =
MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
// using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// B pointer
p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);
// B desc
bs_grid_desc_n_k_(i) =
MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
});
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_,
bs_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
as_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
bs_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
}
// for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i)
{
a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1];
a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1];
}
for(index_t i = 0; i < NumBTensor; ++i)
{
b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1];
b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1];
}
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1];
}
e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
}
// pointers
typename GridwiseGemm::AsGridPointer p_as_grid_;
typename GridwiseGemm::BsGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definition
AsGridDesc_M_K as_grid_desc_m_k_;
BsGridDesc_N_K bs_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_;
BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
std::array<index_t, NumATensor> a_mz_stride_;
std::array<index_t, NumATensor> a_kz_stride_;
std::array<index_t, NumBTensor> b_nz_stride_;
std::array<index_t, NumBTensor> b_kz_stride_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_nz_stride_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle<
GridwiseGemm,
typename GridwiseGemm::AsGridPointer,
typename GridwiseGemm::BsGridPointer,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AsGridDesc_AK0_M_AK1,
DeviceOp::BsGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_as_grid_,
arg.p_bs_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.as_grid_desc_ak0_m_ak1_,
arg.bs_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
// check vector load/store
{
bool all_valid = true;
static_for<0, NumATensor, 1>{}([&](auto i) {
// vector memory access of A: could be on M or AK1 dimension
if constexpr(ABlockTransferSrcVectorDim == 1)
{
if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
ABlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
else
{
if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
ABlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
});
// vector memory access of B: could be on N or BK1 dimension
static_for<0, NumBTensor, 1>{}([&](auto i) {
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
else
{
if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
});
// check vector load of Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
all_valid = false;
}
});
// vector memory access of E: always on NPerBlock dimension
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
all_valid = false;
}
if(!all_valid)
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
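// Illustration of the vector-access rule enforced above: a tensor may only
// be accessed ScalarPerVector elements at a time when its fastest-varying
// dimension is contiguous (stride 1) and the corresponding descriptor
// length is divisible by ScalarPerVector. E.g. with
// CDEBlockTransferScalarPerVector_NPerBlock = 8, an E tensor with
// e_nz_stride_ == 1 and an N extent of 2048 passes (2048 % 8 == 0).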
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
d_ms_ns_lengths,
d_ms_ns_strides,
e_ms_ns_length,
e_ms_ns_stride,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
as_ms_ks_lengths,
as_ms_ks_strides,
bs_ns_ks_lengths,
bs_ns_ks_strides,
ds_ms_ns_lengths,
ds_ms_ns_strides,
e_ms_ns_length,
e_ms_ns_stride,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceContractionMultipleABD_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
PipelineVer,
ComputeType>;
using Argument = typename GridwiseGemm::Argument;
struct Argument : public GridwiseGemm::Argument
{
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
MPadded_,
NPadded_,
KPadded_,
K0_,
k_batch_),
a_element_op(a_element_op_),
b_element_op(b_element_op_),
c_element_op(c_element_op_)
{
}
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CElementwiseOperation c_element_op;
};
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Invoker
......@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
karg.M * karg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
static_cast<typename GridwiseGemm::Argument>(karg),
b2c_map,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
};
if(has_main_k0_block_loop)
......@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch)
{
return Argument{p_a,
return Argument(p_a,
p_b,
p_c,
M,
......@@ -279,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
KBatch};
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......@@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
KBatch);
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
......
......@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
GridwiseGemm,
ADataType,
BDataType,
......
......@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
......@@ -22,32 +23,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
// element-wise op
OutElementwiseOperation a_element_op_;
......@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch,
ComputePtrOffsetOfStridedBatch<I0>,
has_main_loop,
has_double_loop>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1,
typename ck::enable_if<NDimSpatial == 3, bool>::type = false>
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle;
using ADataType = OutDataType;
using BDataType = InDataType;
using CDataType = WeiDataType;
using AElementwiseOperation = OutElementwiseOperation;
using BElementwiseOperation = InElementwiseOperation;
using CElementwiseOperation = WeiElementwiseOperation;
// TODO make A/B datatype different
using ABDataType = InDataType;
// 3d
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto GemmK1Number = Number<K1>{};
static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number;
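// The reduction (K) dimension is consumed as K0 * K1: K1 contiguous elements
// form one vectorized transfer, and K0PerBlock such vectors make up one block
// tile. With illustrative numbers (not taken from any instance in this
// commit), K0PerBlock = 4 and K1 = 8 give KPerBlock = 32 reduction elements
// per main-loop iteration.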
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_out_grid_desc(const index_t N,
const index_t Do,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& output_strides)
{
const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_in_grid_desc(const index_t N,
const index_t Di,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& input_strides)
{
const index_t NStride = input_strides[1];
const index_t DiStride = input_strides[3];
const index_t HiStride = input_strides[4];
const index_t WiStride = input_strides[5];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
}
}
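// Under Filter1x1Stride1Pad0 every output pixel maps 1:1 onto an input pixel,
// so the input collapses to a plain [N * Di * Hi * Wi, C] matrix; the general
// branch keeps the 5-D view so the pad/embed transforms built later can model
// stride, dilation and padding.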
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const index_t K,
const index_t Z,
const index_t Y,
const index_t X,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
make_tuple(KStride, CStride));
}
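// Backward-weight convolution is lowered to an implicit GEMM over the views
// built by the three helpers above:
//   A = output viewed as [GemmKTotal = N*Do*Ho*Wo, GemmM = K]
//   B = input  viewed as [GemmKTotal = N*Do*Ho*Wo, GemmN = Z*Y*X*C]
//   C = weight viewed as [GemmM = K,               GemmN = Z*Y*X*C]
// The function below adds the padding and unmerge transforms that reshape A
// and B into the (K0, M/N, K1) form expected by the gridwise GEMM.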
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const index_t N,
const index_t K,
const index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_strides,
const std::array<index_t, NDimSpatial + 3>& weights_strides,
const std::array<index_t, NDimSpatial + 3>& output_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
using namespace ck;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
const index_t GemmKPad = GemmK0 * GemmK1Number;
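// GemmKTotal is rounded up to a multiple of K0PerBlock * GemmK1Number. As a
// worked example with illustrative sizes: GemmKTotal = 1000, K1 = 8,
// K0PerBlock = 4 gives GemmK0 = ceil(1000 / 32) * 4 = 128 and
// GemmKPad = 128 * 8 = 1024, i.e. 24 zero-padded reduction elements are
// appended by the right-pad transforms below.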
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
}
else
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Pad
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
}
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
const index_t dim = 1;
const std::array<index_t, NDimSpatial> lengths{1, 1, 1};
const std::array<index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
const std::array<index_t, NDimSpatial> params{1, 1, 1};
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params);
}
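// GetABCGridDesc() is never called with real problem sizes; it instantiates
// the descriptor maker with unit lengths and strides purely so the decltype
// aliases below can recover the A/B/C descriptor types at compile time.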
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
// DataType Family
ADataType,
BDataType,
AccDataType,
CDataType,
Tuple<>,
CDataType,
// InMemory Data Descriptor
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
Tuple<>,
CGridDesc_M_N,
// ElementwiseOp Family
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
K1,
MRepeat,
NRepeat,
// ThreadCluster Family
BlockSize,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
NumGemmKPrefetchStage,
LoopSched,
PipelineVer>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{}));
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(
CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */));
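// Block2CTileMap translates a flat workgroup index into an (MBlock, NBlock)
// tile coordinate of C; M01 = N01 = 1 requests the gridwise GEMM's default
// tile ordering.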
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
index_t split_k)
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{},
compute_ptr_offset_of_batch_{},
a_element_op_{out_element_op},
b_element_op_{in_element_op},
c_element_op_{wei_element_op},
Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{a_g_n_c_wis_lengths[2]},
input_spatial_lengths_{},
filter_spatial_lengths_{},
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
a_g_n_c_wis_strides,
b_g_k_c_xs_strides,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
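            // Weights are densely packed per group, so the C/E batch stride is
            // the size of one group's filter tensor: K * C * Z * Y * X
            // elements, accumulated from filter_spatial_lengths_ above.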
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
OutElementwiseOperation a_element_op_;
InElementwiseOperation b_element_op_;
WeiElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
const index_t Conv_G_;
const index_t Conv_N_;
const index_t Conv_K_;
const index_t Conv_C_;
std::array<index_t, NDimSpatial> input_spatial_lengths_;
std::array<index_t, NDimSpatial> filter_spatial_lengths_;
std::array<index_t, NDimSpatial> output_spatial_lengths_;
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
const std::array<index_t, NDimSpatial>& input_left_pads_;
const std::array<index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
void Print(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
Print(arg);
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid "
"setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
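            // Whether the GEMM main loop iterates more than once is fully
            // determined by K0, so it is dispatched as a compile-time
            // constant: two kernel instantiations, each free of the runtime
            // branch.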
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
GridwiseGemm,
ADataType,
BDataType,
typename GridwiseGemm::DsGridPointer,
CDataType,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<I0>,
has_main_loop>;
using EmptyTuple = Tuple<>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
EmptyTuple{}, // Ds
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{},
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
};
if(has_main_k0_block_loop)
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
get_device_name() == "gfx1102")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else
{
return false;
}
// TODO: Add support for split_k > 1
if(arg.k_batch_ != 1)
{
return false;
}
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
{
return false;
}
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// check if it's a 1x1 convolution with stride=1 and no padding
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
{
return false;
}
}
}
// vector load A/B matrix from global memory
if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
// vector store C matrix into global memory
if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
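        // Both divisibility checks reflect the supported layouts: K is the
        // contiguous dimension of the output (GEMM A) and C of the input
        // (GEMM B) and the weight (GEMM C), so each must be a multiple of the
        // corresponding ScalarPerVector for vectorized global-memory access.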
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const index_t split_k)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid,
const void* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const index_t split_k) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid),
a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1 << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMRepeatPerShuffle << ", "
<< CShuffleNRepeatPerShuffle << ", "
<< CShuffleBlockTransferScalarPerVector_NPerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
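// A hedged usage sketch of this device op; the template arguments are
// placeholders, not tuned values from this commit:
//
//   using DeviceOp = ck::tensor_operation::device::
//       DeviceGroupedConvBwdWeight_Wmma_CShuffle</* NDimSpatial = */ 3,
//                                                /* ...layouts, types,
//                                                   tiling parameters... */>;
//   DeviceOp op;
//   auto arg = op.MakeArgumentPointer(p_in, p_wei, p_out,
//                                     in_lengths, in_strides,
//                                     wei_lengths, wei_strides,
//                                     out_lengths, out_strides,
//                                     conv_strides, dilations,
//                                     left_pads, right_pads,
//                                     in_op, wei_op, out_op,
//                                     /* split_k = */ 1); // only 1 supported
//   if(op.IsSupportedArgument(arg.get()))
//   {
//       auto invoker = op.MakeInvokerPointer();
//       invoker->Run(arg.get(), StreamConfig{});
//   }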