Commit 11279540 authored by Astha Rai's avatar Astha Rai
Browse files

Merge branch 'transpose_5d' of github.com:ROCmSoftwarePlatform/composable_kernel into transpose_5d

parents 14daa201 33e78b9a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Relu;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
constexpr ck::index_t NDimSpatial = 3;
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using OutDataType = ck::half_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InLayout = ck::tensor_layout::convolution::GNDHWC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::GNDHWK;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename OutElementOp>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<OutLayout, OutLayout>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<OutDataType, OutDataType>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8>;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
namespace {
// Use custom implementation to pass two more tensors for post op
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
constexpr ck::index_t NumDs = 2;
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::array<Tensor<OutDataType>, NumDs> d_tensors = {Tensor<OutDataType>(out_g_n_k_wos_desc),
Tensor<OutDataType>(out_g_n_k_wos_desc)};
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
d_tensors[0].GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
d_tensors[1].GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
d_tensors[0].GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.05, 0.05});
d_tensors[1].GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.05, 0.05});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem d0_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize());
DeviceMem d1_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
d0_buf.ToDevice(d_tensors[0].mData.data());
d1_buf.ToDevice(d_tensors[1].mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
const std::array<const void*, NumDs> ds = {d0_buf.GetDeviceBuffer(), d1_buf.GetDeviceBuffer()};
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
ds,
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, NumDs>{
e_g_n_k_wos_lengths, e_g_n_k_wos_lengths},
std::array<std::array<ck::index_t, NDimSpatial + 3>, NumDs>{
e_g_n_k_wos_strides, e_g_n_k_wos_strides},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error("The device op with the specified compilation parameters does "
"not support this convolution problem.");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop =
conv_param.GetFlops() + 2 * conv_param.GetOutputByte<OutDataType>() / sizeof(OutDataType);
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>() +
2 * conv_param.GetOutputByte<OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NumDs>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
out_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op,
d_tensors);
ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!");
}
return true;
}
} // namespace
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Sigmoid;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::SoftRelu;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::TanH;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
bool do_verification = true;
// Use floats for SoftRelu by default to avoid overflow after e^x.
int init_method =
std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::SoftRelu> ? 2 : 1;
bool time_kernel = false;
// Following shapes are selected to avoid overflow. Expect inf in case of
// size increase for some elementwise ops.
ck::utils::conv::ConvParam conv_param{
3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
const auto run = [&]() {
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv_fwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdActivInstance>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
};
if(conv_param.num_dim_spatial_ == 3)
{
return run();
}
return false;
}
......@@ -11,31 +11,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
set(test 1)
endif()
if(test EQUAL 1)
message("removing example source file ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
......@@ -62,37 +58,39 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
if(result EQUAL 0)
add_dependencies(${EXAMPLE_NAME} ${FILE_NAME})
endif()
endfunction(add_example_dependencies EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
set(test 1)
endif()
if(test EQUAL 1)
message("removing example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
......
......@@ -66,6 +66,10 @@
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
#endif
// MFMA instruction
......
......@@ -3,8 +3,10 @@
#pragma once
#include <sstream>
#include <hip/hip_runtime.h>
// To be removed, which really does not tell the location of failed HIP functional call
inline void hip_check_error(hipError_t x)
{
if(x != hipSuccess)
......@@ -15,3 +17,16 @@ inline void hip_check_error(hipError_t x)
throw std::runtime_error(ss.str());
}
}
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
{ \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
{ \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
<< hipGetErrorString(_tmpVal); \
throw std::runtime_error(ostr.str()); \
} \
} while(0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A0[M0, M1, ... K0, K1, ...], ...
// input : B0[N0, N1, ... K0, K1, ...], ...
// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ...
// output : E[M0, M1, ... N0, N1, ...]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceContractionMultipleABD : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -33,7 +33,8 @@ template <index_t NumDimM,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
......
......@@ -14,11 +14,12 @@ namespace device {
/**
* \brief Convolution Tensor Rearrange.
*
* This Device operator supports conversion image ([G, N, Di, Hi, Wi, C]) to
* the gemm problem([N * Do * Ho * Wo, Z * Y * X * C]) (Image to Column) and
* conversion gemm form to the image (Column to Image).
*
* Note that G must be equal to 1.
* This Device operator supports converting an image to
* the GEMM representation (Image to Column) and
* converting a GEMM form to the image (Column to Image).
* Supported layouts:
* [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C]
* [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C]
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Input Layout.
......@@ -39,13 +40,14 @@ struct DeviceConvTensorRearrange : public BaseOperator
*
* \param p_in A pointer to the device memory of the input image.
* \param p_out A pointer to the device memory of the output.
* \param G Convolution number of groups.
* \param N Convolution batch size.
* \param C Convolution number of channels.
* \param input_spatial_lengths Input spatial lengths.
* \param filter_spatial_lengths Filter spatial lengths.
* \param output_spatial_lengths Output spatial lengths.
* \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
* \param gemm_m_k_strides Gemm form strides.
* \param gemm_g_m_k_strides Gemm form strides.
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Convolution left pads.
......@@ -55,13 +57,14 @@ struct DeviceConvTensorRearrange : public BaseOperator
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
void* p_out,
const ck::index_t G,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......
......@@ -14,8 +14,8 @@ namespace device {
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
......@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims,
double epsilon,
const void* p_x,
......@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>>;
......
......@@ -17,15 +17,18 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Image to column for input layout NDHWC:
// input : image converted to the gemm problem [N * Do * Ho * Wo, Z * Y * X * C]
// output : image [N, Di, Hi, Wi, C]
// Column to Image:
// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
// output : input image [G, N, Di, Hi, Wi, C]
// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
// output : input image [N, Di, Hi, Wi, G, C]
template <index_t NDimSpatial,
typename ImageLayout,
typename InputDataType,
......@@ -43,6 +46,14 @@ struct DeviceColumnToImageImpl
OutputDataType,
conv_tensor_rearrange_op::ColumnToImage>
{
static constexpr bool is_NSpatialGC =
std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
static constexpr bool is_GNSpatialC =
std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -90,7 +101,7 @@ struct DeviceColumnToImageImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& independent_filters,
const std::array<index_t, NDimSpatial>& effs)
{
......@@ -100,23 +111,23 @@ struct DeviceColumnToImageImpl
C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t NStride = DoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2];
// Calculate the appropriate stride for each set of independent filters
// in each dimension
const index_t WStride =
math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * gemm_m_k_strides[I0];
const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) *
gemm_g_m_k_strides[I1];
const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
output_spatial_lengths[XIdx] * gemm_m_k_strides[I0];
output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1];
const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
gemm_m_k_strides[I0];
gemm_g_m_k_strides[I1];
// Create descriptor for independent filters in each dimension and
// then merge them into column form
if constexpr(NDimSpatial == 1)
{
const auto desc_gemm_form =
make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
make_tuple(NStride, WStride, gemm_m_k_strides[I1]));
make_tuple(NStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form,
make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
......@@ -130,7 +141,7 @@ struct DeviceColumnToImageImpl
{
const auto desc_gemm_form = make_naive_tensor_descriptor(
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
make_tuple(NStride, HStride, WStride, gemm_m_k_strides[I1]));
make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form,
make_tuple(make_merge_transform(
......@@ -149,7 +160,7 @@ struct DeviceColumnToImageImpl
independent_filters[YIdx],
independent_filters[XIdx],
CZYX),
make_tuple(NStride, DStride, HStride, WStride, gemm_m_k_strides[I1]));
make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form,
make_tuple(make_merge_transform(make_tuple(N,
......@@ -252,34 +263,38 @@ struct DeviceColumnToImageImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
InputGridDesc{}))>;
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Add,
Block2ETileMap>;
using GridwiseTensorRearrangeKernel =
GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Add,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>>;
struct Argument : public BaseArgument
{
Argument(const void* p_in, // input image
void* p_out, // output image
const ck::index_t G,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
: C_(C),
: G_(G),
C_(C),
X_(filter_spatial_lengths[NDimSpatial - I1]),
p_in_{static_cast<const InputDataType*>(p_in)},
p_out_{static_cast<OutputDataType*>(p_out)},
......@@ -289,6 +304,9 @@ struct DeviceColumnToImageImpl
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0];
const index_t x_eff =
(filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1;
const index_t y_eff =
......@@ -354,7 +372,7 @@ struct DeviceColumnToImageImpl
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
gemm_m_k_strides,
gemm_g_m_k_strides,
independent_filters,
effs);
const auto out_grid_desc_m_k =
......@@ -387,10 +405,9 @@ struct DeviceColumnToImageImpl
// Memory offsets to next set of independent filters,
// move to independent filters in each dimension
const index_t in_offset =
x_idx * gemm_m_k_strides[0] +
y_idx * gemm_m_k_strides[0] * output_spatial_lengths[XIdx] +
z_idx * gemm_m_k_strides[0] * output_spatial_lengths[YIdx] *
output_spatial_lengths[XIdx];
(x_idx + y_idx * output_spatial_lengths[XIdx] +
z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) *
gemm_g_m_k_strides[I1];
// Move to independent filters in appropriate dimensions
const index_t out_offset =
x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
......@@ -417,6 +434,7 @@ struct DeviceColumnToImageImpl
}
}
const ck::index_t G_;
const ck::index_t C_;
const ck::index_t X_;
......@@ -434,6 +452,8 @@ struct DeviceColumnToImageImpl
std::vector<const InputDataType*> p_in_container_;
std::vector<OutputDataType*> p_out_container_;
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
};
struct Invoker : public BaseInvoker
......@@ -451,6 +471,7 @@ struct DeviceColumnToImageImpl
OutputGridDesc,
OutputDataType,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>,
GridwiseTensorRearrangeKernel>;
// Execute each set of independent filters
......@@ -460,7 +481,7 @@ struct DeviceColumnToImageImpl
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
arg.out_grid_desc_m_k_container_[i]);
const index_t grid_size =
block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]);
block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_;
elapsed_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
......@@ -470,7 +491,9 @@ struct DeviceColumnToImageImpl
arg.p_in_container_[i],
arg.out_grid_desc_m_k_container_[i],
arg.p_out_container_[i],
block_2_tile_map);
arg.G_,
block_2_tile_map,
arg.compute_ptr_offset_of_batch_);
}
return elapsed_time;
}
......@@ -485,8 +508,7 @@ struct DeviceColumnToImageImpl
bool IsSupportedArgument(const Argument& arg)
{
using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
if constexpr(!(is_NSpatialGC || is_GNSpatialC))
{
return false;
}
......@@ -534,13 +556,14 @@ struct DeviceColumnToImageImpl
static auto MakeArgument(const void* p_in, // input image
void* p_out, // output image
const ck::index_t G,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -548,13 +571,14 @@ struct DeviceColumnToImageImpl
{
return Argument{static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out),
G,
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
gemm_m_k_strides,
gemm_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -566,13 +590,14 @@ struct DeviceColumnToImageImpl
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, // input image
void* p_out, // output image
const ck::index_t G,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -580,13 +605,14 @@ struct DeviceColumnToImageImpl
{
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out),
G,
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
gemm_m_k_strides,
gemm_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename AsPointer,
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AsGridDesc_AK0_M_AK1,
typename BsGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_multiple_abd_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_as_grid;
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = as_grid_desc_ak0_m_ak1;
ignore = bs_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceContractionMultipleABD_Xdl_CShuffle
: public DeviceContractionMultipleABD<NumDimM,
NumDimN,
NumDimK,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle;
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ComputeDataType = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
AsDataType,
BsDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
static constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_,
const std::vector<index_t>& a_ms_ks_strides_)
{
assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK &&
a_ms_ks_strides_.size() == NumDimM + NumDimK);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number<NumDimM + NumDimK>{});
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number<NumDimM + NumDimK>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
__host__ __device__ static auto
MakeAsGridDescriptor_M_K(const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides)
{
return generate_tuple(
[&](auto i) {
return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]);
},
Number<NumATensor>{});
}
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_,
const std::vector<index_t>& b_ns_ks_strides_)
{
assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK &&
b_ns_ks_strides_.size() == NumDimN + NumDimK);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number<NumDimN + NumDimK>{});
const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number<NumDimN + NumDimK>{});
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
__host__ __device__ static auto
MakeBsGridDescriptor_N_K(const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides)
{
return generate_tuple(
[&](auto i) {
return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]);
},
Number<NumBTensor>{});
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_,
const std::vector<index_t>& e_ms_ns_strides_)
{
assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN &&
e_ms_ns_strides_.size() == NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number<NumDimM + NumDimN>{});
const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number<NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const auto e_grid_desc_ms_ns =
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
e_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto
MakeDsGridDescriptor_M_N(const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
},
Number<NumDTensor>{});
}
// desc for problem definition
using AsGridDesc_M_K = remove_cvref_t<decltype(MakeAsGridDescriptor_M_K({}, {}))>;
using BsGridDesc_N_K = remove_cvref_t<decltype(MakeBsGridDescriptor_N_K({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N({}, {}))>;
// desc for blockwise copy
using AsGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(AsGridDesc_M_K{}))>;
using BsGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(BsGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(std::array<const void*, NumATensor> p_as_grid,
std::array<const void*, NumBTensor> p_bs_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_as_grid_{},
p_bs_grid_{},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
as_grid_desc_m_k_{},
bs_grid_desc_n_k_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)},
as_grid_desc_ak0_m_ak1_{},
bs_grid_desc_bk0_n_bk1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
// using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);
// A desc
as_grid_desc_m_k_(i) =
MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
// using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// B pointer
p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);
// B desc
bs_grid_desc_n_k_(i) =
MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
});
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_,
bs_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
as_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
bs_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
}
// for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i)
{
a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1];
a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1];
}
for(index_t i = 0; i < NumBTensor; ++i)
{
b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1];
b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1];
}
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1];
}
e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
}
// pointers
typename GridwiseGemm::AsGridPointer p_as_grid_;
typename GridwiseGemm::BsGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definiton
AsGridDesc_M_K as_grid_desc_m_k_;
BsGridDesc_N_K bs_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_;
BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
std::array<index_t, NumATensor> a_mz_stride_;
std::array<index_t, NumATensor> a_kz_stride_;
std::array<index_t, NumBTensor> b_nz_stride_;
std::array<index_t, NumBTensor> b_kz_stride_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_nz_stride_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle<
GridwiseGemm,
typename GridwiseGemm::AsGridPointer,
typename GridwiseGemm::BsGridPointer,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AsGridDesc_AK0_M_AK1,
DeviceOp::BsGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_as_grid_,
arg.p_bs_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.as_grid_desc_ak0_m_ak1_,
arg.bs_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
// check vector load/store
{
bool all_valid = true;
static_for<0, NumATensor, 1>{}([&](auto i) {
// vector memory access of A: could be on M or AK1 dimension
if constexpr(ABlockTransferSrcVectorDim == 1)
{
if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
ABlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
else
{
if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
ABlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
});
// vector memory access of B: could be on N or BK1 dimension
static_for<0, NumBTensor, 1>{}([&](auto i) {
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
else
{
if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
});
// check vector load of Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
all_valid = false;
}
});
// vector memory access of E: always on NPerBlock dimension
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
all_valid = false;
}
if(!all_valid)
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
d_ms_ns_lengths,
d_ms_ns_strides,
e_ms_ns_length,
e_ms_ns_stride,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
as_ms_ks_lengths,
as_ms_ks_strides,
bs_ns_ks_lengths,
bs_ns_ks_strides,
ds_ms_ns_lengths,
ds_ms_ns_strides,
e_ms_ns_length,
e_ms_ns_stride,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceContractionMultipleABD_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -145,7 +145,8 @@ template <index_t NumDimM,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
typename ComputeDataType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceContractionMultipleD_Xdl_CShuffle
: public DeviceContractionMultipleD<NumDimM,
NumDimN,
......@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
......@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using ComputeDataType = ADataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
......
......@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwiseImpl<" ;
str << "NumDim_" << NumDim << ",";
str << "MPerThread_" << MPerThread << ",";
str << "InScalarPerVector";
static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; });
str << ",";
str << "OutScalarPerVector";
static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; });
str << ">";
// clang-format on
return str.str();
}
}; // namespace device
} // namespace device
......
......@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false;
}
}
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......
......@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// clang-format off
str << "DeviceGemm_Xdl_CShuffle"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
......@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];;
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
......
......@@ -59,7 +59,8 @@ template <typename ADataType,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename ComputeType = CDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BLayout,
......@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
......@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
PipelineVer,
ComputeType>;
using Argument = typename GridwiseGemm::Argument;
struct Argument : public GridwiseGemm::Argument
{
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0Padded_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
MPadded_,
NPadded_,
KPadded_,
K0Padded_,
k_batch_),
a_element_op(a_element_op_),
b_element_op(b_element_op_),
c_element_op(c_element_op_)
{
}
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CElementwiseOperation c_element_op;
};
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Invoker
......@@ -155,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto b2c_map = DefaultBlock2CTileMap{};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto K0 = karg.K0;
const auto K0Padded = karg.K0Padded;
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
float ave_time = 0;
......@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
karg.M * karg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
static_cast<typename GridwiseGemm::Argument>(karg),
b2c_map,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
};
if(has_main_k0_block_loop)
......@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;
DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel);
}
......@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch)
{
return Argument{p_a,
return Argument(p_a,
p_b,
p_c,
M,
......@@ -278,8 +342,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
KBatch};
GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......@@ -311,8 +378,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
KBatch);
GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
......@@ -322,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
// polymorphic
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
<< ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
return str.str();
}
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment