Commit c9d9e24d authored by Jing Zhang's avatar Jing Zhang
Browse files

add reduce class into gemm deviceOp

parent 0b91a789
......@@ -77,3 +77,6 @@ add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
endif()
add_example_executable(gemm_xdl_input_i16_comp_i8_scale_ab gemm_xdl_input_i16_comp_i8_scale_ab.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_scale_ab_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = I32;
using BDataType = I32;
using AccDataType = I32;
using CShuffleDataType = I32;
using EDataType = I32;
using ComputeType = I8;
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
struct i32_to_i8
{
__host__ __device__ void operator()(I8& y, const I32& x) const
{
y = ck::type_convert<I8>(x) * scale;
}
float scale = 1.0;
};
using AElementOp = i32_to_i8;
using BElementOp = i32_to_i8;
using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDScaleAB_Xdl_CShuffle<
ALayout,
BLayout,
ck::Tuple<>,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
8,
1,
1,
1,
S<1, 32, 1, 8>,
4,
ck::LoopScheduler::Default,
ck::PipelineVersion::v1,
ComputeType>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{0.2};
auto b_element_op = BElementOp{0.2};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough,
ComputeType>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDScaleAB : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_scale_ab.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
namespace ck {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
: public DeviceGemmMultipleDScaleAB<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleDScaleAB_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
// desc for problem definition
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw}
{
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
};
template <typename InOutDataType>
struct Reduce2D
{
static constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AMAX;
static constexpr bool PropagateNan = true;
static constexpr bool OutputIndex = false;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType,
InOutDataType,
InOutDataType,
2, // Rank
1, // NumReduceDim
ReduceOperation,
InElementwiseOperation,
PassThroughOp,
InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
32,
8,
1,
1,
1, // vector dim
1,
1>;
using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
InOutDataType,
InOutDataType,
1, // Rank
1, // NumReduceDim
ReduceOperation,
PassThroughOp,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
128,
2,
1,
1,
1, // vector dim
1,
1>;
float Run(const std::array<index_t, 2> arrInLengths_1,
const std::array<index_t, 2> arrInStrides_1,
const void* p_input,
void* p_workpace,
void* p_output,
const StreamConfig& stream_config = StreamConfig{})
{
const std::array<int, 1> reduceDims_1 = {arrInLengths_1[0] > arrInLengths_1[1] ? 0 : 1};
const std::array<int, 1> reduceDims_2 = {0};
std::array<index_t, 1> arrInLengths_2{arrInLengths_1[reduceDims_1[0]]};
std::array<index_t, 1> arrInStrides_2{1};
std::array<index_t, 1> arrOutLengths{1};
std::array<index_t, 1> arrOutStrides{1};
InElementwiseOperation in_elementwise_op;
AccElementwiseOperation acc_elementwise_op;
auto reduce_1 = DeviceReduceInstance_1{};
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(arrInLengths_1,
arrInStrides_1,
arrInLengths_2,
arrInStrides_2,
reduceDims_1,
1.0,
0.0,
p_input,
nullptr,
p_workpace,
nullptr,
in_elementwise_op,
PassThroughOp{});
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
{
std::cout << "The runtime parameters seems supported by the DeviceReduce instance, "
"exiting!"
<< std::endl;
};
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
auto reduce_2 = DeviceReduceInstance_2{};
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(arrInLengths_2,
arrInStrides_2,
arrOutLengths,
arrOutStrides,
reduceDims_2,
1.0,
0.0,
p_workpace,
nullptr,
p_output,
nullptr,
PassThroughOp{},
acc_elementwise_op);
if(!reduce_2.IsSupportedArgument(argument_ptr_2.get()))
{
std::cout << "The runtime parameters seems not supported by the DeviceReduce "
"instance, exiting!"
<< std::endl;
};
auto invoker_ptr_2 = reduce_2.MakeInvokerPointer();
float kern_time = 0;
kern_time += invoker_ptr_1->Run(argument_ptr_1.get(), stream_config);
kern_time += invoker_ptr_2->Run(argument_ptr_2.get(), stream_config);
return kern_time;
}
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
using RowMajor = tensor_layout::gemm::RowMajor;
float kern_time = 0;
auto reduce_a = Reduce2D<ADataType>{};
kern_time += reduce_a.Run({arg.MRaw_, arg.KRaw_},
is_same<RowMajor, ALayout>::value
? std::array<index_t, 2>{arg.KRaw_, I1}
: std::array<index_t, 2>{I1, arg.MRaw_},
arg.p_a_grid_,
arg.p_e_grid_,
arg.p_e_grid_,
stream_config);
auto reduce_b = Reduce2D<BDataType>{};
kern_time += reduce_b.Run({arg.KRaw_, arg.NRaw_},
is_same<RowMajor, BLayout>::value
? std::array<index_t, 2>{arg.NRaw_, I1}
: std::array<index_t, 2>{I1, arg.KRaw_},
arg.p_a_grid_,
arg.p_e_grid_,
arg.p_e_grid_,
stream_config);
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
BDataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
kern_time += launch_kernel(integral_constant<bool, true>{});
}
else
{
kern_time += launch_kernel(integral_constant<bool, false>{});
}
return kern_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
// check vector load/store
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector laod of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of Ds
// only support RowMajor for now
bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>)
{
all_valid = false;
}
});
if(!all_valid)
{
return false;
}
// check vector store of E
// only support RowMajor for now
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmMultipleDScaleAB_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment