Commit 2b27d5fc authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into rosenrodt/gemm-layernorm

parents f689a155 fa9a0a5c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t M0 = 4;
ck::index_t M1 = 32;
ck::index_t M2 = 128;
ck::index_t N0 = 16;
ck::index_t N1 = 256;
// GEMM shape
ck::index_t M = M0 * M1 * M2;
ck::index_t N = N0 * N1;
ck::index_t K = 128;
ck::index_t stride_A = K;
ck::index_t stride_B = K;
#if 1
// E = [M0, N0, M1, N1, M2]
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
#else
// D = [0, 0, 0, N0, N1]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
// E = [M0, M1, M2, N0, N1]
ck::index_t stride_E_M0 = M1 * M2 * N0 * N1;
ck::index_t stride_E_M1 = M2 * N0 * N1;
ck::index_t stride_E_M2 = N0 * N1;
ck::index_t stride_E_N0 = N1;
ck::index_t stride_E_N1 = 1;
#endif
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{
M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1};
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{
M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
auto f_host_de_tensor_descriptor =
[](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) {
std::size_t m0 = de_grid_desc.M0_;
std::size_t m1 = de_grid_desc.M1_;
std::size_t m2 = de_grid_desc.M2_;
std::size_t n0 = de_grid_desc.N0_;
std::size_t n1 = de_grid_desc.N1_;
std::size_t stride_m0 = de_grid_desc.stride_M0_;
std::size_t stride_m1 = de_grid_desc.stride_M1_;
std::size_t stride_m2 = de_grid_desc.stride_M2_;
std::size_t stride_n0 = de_grid_desc.stride_N0_;
std::size_t stride_n1 = de_grid_desc.stride_N1_;
return HostTensorDescriptor(
std::vector<std::size_t>({m0, m1, m2, n0, n1}),
std::vector<std::size_t>({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1}));
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
Tensor<DDataType> d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl;
std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) *
d_m0_m1_m2_n0_n1.mDesc.GetElementSpace());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) *
e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< device_op.GetTypeString() << std::endl;
if(do_verification)
{
Tensor<AccDataType> c_m_n(HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m0 = 0; m0 < M0; ++m0)
for(int m1 = 0; m1 < M1; ++m1)
for(int m2 = 0; m2 < M2; ++m2)
for(int n0 = 0; n0 < N0; ++n0)
for(int n1 = 0; n1 < N1; ++n1)
{
int m = m0 * M1 * M2 + m1 * M2 + m2;
int n = n0 * N1 + n1;
cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1),
ck::type_convert<EDataType>(c_m_n(m, n)),
d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1));
}
e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data());
return ck::utils::check_err(e_m0_m1_m2_n0_n1_device_result.mData,
e_m0_m1_m2_n0_n1_host_result.mData)
? 0
: 1;
}
return 0;
}
......@@ -42,3 +42,4 @@ add_subdirectory(20_convnd_bwd_weight_xdl)
add_subdirectory(21_gemm_layernorm)
add_subdirectory(22_cgemm)
add_subdirectory(23_softmax)
add_subdirectory(25_gemm_bias_c_permute)
......@@ -10,7 +10,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
......@@ -35,7 +35,7 @@ template <typename ADataType,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Device5AryElementwise : public BaseOperator
struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{
static constexpr auto I0 = Number<0>{};
......@@ -268,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
return true;
};
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
static auto MakeArgument(std::array<const void*, 5> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
......@@ -283,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{
return Argument{p_a,
p_b,
p_c,
p_d,
p_e,
p_f,
return Argument{static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_outputs[0]),
lengths,
a_strides,
b_strides,
......@@ -299,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
functor};
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_c,
const void* p_d,
const void* p_e,
void* p_f,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, 5> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const CDataType*>(p_c),
static_cast<const DDataType*>(p_d),
static_cast<const EDataType*>(p_e),
static_cast<FDataType*>(p_f),
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_outputs[0]),
lengths,
a_strides,
b_strides,
c_strides,
d_strides,
e_strides,
f_strides,
input_strides[0],
input_strides[1],
input_strides[2],
input_strides[3],
input_strides[4],
output_strides[0],
functor);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Device5aryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< "DScalarPerVector = " << DScalarPerVector
<< "EScalarPerVector = " << EScalarPerVector
<< "FScalarPerVector = " << FScalarPerVector
<< ">";
// clang-format on
return str.str();
}
}; // namespace device
} // namespace device
} // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t Batch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmPtr = std::unique_ptr<
DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -23,16 +23,16 @@ namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock,
typename ReduceGridDescriptor_MBlock_MPerBlock,
typename ComputeBasePrtOfBatch,
typename Block2CTileMap,
bool HasMainK0BlockLoop>
......@@ -44,18 +44,18 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
DPtrsGlobal p_ds_grid,
ReducePtrsGlobal p_reduces_grid,
const index_t batch_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const ReduceInElementwiseOperations reduce_in_element_ops,
const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
const Block2CTileMap block_2_ctile_map)
{
......@@ -71,10 +71,10 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
});
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -82,36 +82,36 @@ __global__ void
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_ds_grid,
p_reduces_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
reduce_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_ds_grid;
ignore = p_reduces_grid;
ignore = batch_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = reduce_in_element_ops;
ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock;
ignore = reduce_grid_desc_mblock_mperblock;
ignore = compute_base_ptr_of_batch_;
ignore = block_2_ctile_map;
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
#endif
}
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
......@@ -126,14 +126,14 @@ template <typename ALayout,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -168,12 +168,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle
: public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
......@@ -446,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
// assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw)
static auto MakeReduceGridDescriptor_M(index_t MRaw)
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
......@@ -474,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
struct ComputeBasePtrOfStridedBatch
{
......@@ -527,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
CShuffleDataType,
CDataType,
ReduceAccDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
DGridDesc_M,
ReduceGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -582,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DPtrsGlobal p_ds_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -592,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount)
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops,
index_t Batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_ds_grid_{p_ds_grid},
BatchCount_(BatchCount),
p_reduces_grid_{p_reduces_grid},
Batch_(Batch),
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
reduce_grid_desc_mblock_mperblock_{},
compute_base_ptr_of_batch_{
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -627,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
}
}
......@@ -636,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
DPtrsGlobal p_ds_grid_;
index_t BatchCount_;
ReducePtrsGlobal p_reduces_grid_;
index_t Batch_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_;
ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
// Invoker
......@@ -663,7 +659,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
{
#if 0
{
std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl;
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
......@@ -678,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
......@@ -692,7 +688,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -704,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
......@@ -727,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.BatchCount_,
arg.p_reduces_grid_,
arg.Batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
......@@ -747,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
......@@ -770,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.BatchCount_,
arg.p_reduces_grid_,
arg.Batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
......@@ -824,39 +820,77 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount)
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch)
{
return Argument{p_a,
p_b,
p_c,
p_dxs,
MRaw,
NRaw,
KRaw,
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
BatchCount};
reduce_in_element_ops,
reduce_out_element_ops,
Batch};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -865,38 +899,74 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount) override
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
dxs_tuple,
MRaw,
NRaw,
KRaw,
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
BatchCount);
reduce_in_element_ops,
reduce_out_element_ops,
Batch);
}
// polymorphic
......
......@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/device_utility/device_prop.hpp"
......@@ -152,7 +152,7 @@ template <typename ADataType,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceBatchedGemmXdl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
: public DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -339,11 +339,11 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount)
index_t Batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
BatchCount_(BatchCount),
Batch_(Batch),
a_grid_desc_k0_m_k1_{
DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
b_grid_desc_k0_n_k1_{
......@@ -376,7 +376,7 @@ struct DeviceBatchedGemmXdl
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t BatchCount_;
index_t Batch_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
......@@ -420,7 +420,7 @@ struct DeviceBatchedGemmXdl
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -451,7 +451,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.Batch_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
......@@ -485,7 +485,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.Batch_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
......@@ -539,7 +539,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount)
index_t Batch)
{
return Argument{p_a,
p_b,
......@@ -555,7 +555,7 @@ struct DeviceBatchedGemmXdl
a_element_op,
b_element_op,
c_element_op,
BatchCount};
Batch};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -573,7 +573,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount) override
index_t Batch) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -589,7 +589,7 @@ struct DeviceBatchedGemmXdl
a_element_op,
b_element_op,
c_element_op,
BatchCount);
Batch);
}
// polymorphic
......
......@@ -9,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck {
......@@ -25,7 +26,7 @@ template <typename ADataType,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{
static constexpr auto I0 = Number<0>{};
......@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
ElementwiseFunctor functor)
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, 2> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<CDataType*>(p_outputs[0]),
lengths,
a_strides,
b_strides,
c_strides,
input_strides[0],
input_strides[1],
output_strides[0],
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
......@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< ">";
// clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
std::array<void*, NumOutputTensor> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
using DeviceElementwisePtr =
std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -29,20 +29,20 @@ template <typename ALayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename C1DataType,
typename BiasDataType,
typename D0DataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
typename D0ElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -77,13 +77,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasAddReduce_Xdl_CShuffle
: public DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()>
{
using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
......@@ -356,7 +350,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
}
// assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw)
static auto MakeReduceGridDescriptor_M(index_t MRaw)
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
......@@ -386,7 +380,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
......@@ -394,25 +388,25 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GemmAccDataType,
CShuffleDataType,
CDataType,
C0DataType,
C1DataType,
BiasDataType,
D0DataType,
ReduceAccDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
D0ElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
C0GridDesc_M_N,
C1GridDesc_M_N,
DGridDesc_M,
ReduceGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -455,9 +449,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const C0DataType* p_c0_grid,
const C1DataType* p_c1_grid,
DPtrsGlobal p_ds_grid,
const BiasDataType* p_bias_grid,
const D0DataType* p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -468,32 +462,32 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
D0ElementwiseOperation d0_element_op,
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid},
p_c1_grid_{p_c1_grid},
p_ds_grid_{p_ds_grid},
p_bias_grid_{p_bias_grid},
p_d0_grid_{p_d0_grid},
p_reduces_grid_{p_reduces_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
reduce_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
c1_element_op_{c1_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
d0_element_op_{d0_element_op},
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -512,8 +506,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c1_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
}
}
......@@ -521,29 +515,30 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
DPtrsGlobal p_ds_grid_;
const BiasDataType* p_bias_grid_;
const D0DataType* p_d0_grid_;
ReducePtrsGlobal p_reduces_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_;
ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
C1ElementwiseOperation c1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
D0ElementwiseOperation d0_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
// Invoker
......@@ -574,21 +569,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
C0DataType,
C1DataType,
DPtrsGlobal,
BiasDataType,
D0DataType,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
D0ElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
......@@ -601,21 +596,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_,
arg.p_bias_grid_,
arg.p_d0_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.d0_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
else
......@@ -624,21 +619,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
C0DataType,
C1DataType,
DPtrsGlobal,
BiasDataType,
D0DataType,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
D0ElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
......@@ -651,21 +646,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_,
arg.p_bias_grid_,
arg.p_d0_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.d0_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
......@@ -700,45 +695,76 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const C0DataType* p_c0,
const C1DataType* p_c1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 1> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 1> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_c0,
p_c1,
p_dxs,
MRaw,
NRaw,
KRaw,
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const BiasDataType*>(p_bias),
static_cast<const D0DataType*>(p_ds[0]),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
StrideDs[0],
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op};
d_element_op,
reduce_in_element_ops,
reduce_out_element_ops};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -747,45 +773,74 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 1> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 1> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t /* KBatch */ = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const C0DataType*>(p_c0),
static_cast<const C1DataType*>(p_c1),
dxs_tuple,
MRaw,
NRaw,
KRaw,
static_cast<const BiasDataType*>(p_bias),
static_cast<const D0DataType*>(p_ds[0]),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
StrideDs[0],
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op);
d_element_op,
reduce_in_element_ops,
reduce_out_element_ops);
}
// polymorphic
......@@ -800,7 +855,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmReduce_Xdl_CShuffle"
str << "DeviceGemmBiasAddReduce_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct DEGridDesc_M0_M1_M2_N0_N1
{
ck::index_t M0_, M1_, M2_, N0_, N1_;
ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_;
};
// input : A[M, K], B[K, N],
// input : D[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D)
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmBiasCPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_d,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasCPermutePtr = std::unique_ptr<
DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatDsPointer,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_bias_c_permute(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// input : A[M, K], or A[K, N]
// input : B[K, N], or A[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <typename ALayout,
typename BLayout,
typename CDELayout,
typename ADataType,
typename BDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename DDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmBiasCPermute_Xdl;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t NumDTensor = I1;
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc)
{
index_t M0 = d_e_grid_desc.M0_;
index_t M1 = d_e_grid_desc.M1_;
index_t M2 = d_e_grid_desc.M2_;
index_t N0 = d_e_grid_desc.N0_;
index_t N1 = d_e_grid_desc.N1_;
index_t stride_M0 = d_e_grid_desc.stride_M0_;
index_t stride_M1 = d_e_grid_desc.stride_M1_;
index_t stride_M2 = d_e_grid_desc.stride_M2_;
index_t stride_N0 = d_e_grid_desc.stride_N0_;
index_t stride_N1 = d_e_grid_desc.stride_N1_;
const auto MRaw = M0 * M1 * M2;
const auto NRaw = N0 * N1;
const auto c_grid_desc_mraw_nraw = [&]() {
const auto c_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
make_tuple(M0, M1, M2, N0, N1),
make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1));
return transform_tensor_descriptor(
c_grid_desc_m0_m1_m2_n0_n1,
make_tuple(make_merge_transform(make_tuple(M0, M1, M2)),
make_merge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{}));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
const void* p_d_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);
const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_bias_c_permute<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_d,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_d,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_d,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_d,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmBiasCPermute_Xdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -9,91 +9,34 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
template <ck::index_t NumDTensor, ck::index_t NumReduce>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
void* p_dxs,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
std::array<ck::index_t, NumDTensor> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, NumDTensor> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_ops,
std::array<void*, NumReduce> reduce_out_element_ops,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
template <ck::index_t NumDTensor, ck::index_t NumReduce>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -32,14 +32,14 @@ template <typename ALayout,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -74,11 +74,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
......@@ -350,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
}
}
// assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw)
// assume Reduce is packed tensor
static auto MakeReduceGridDescriptor_M(index_t MRaw)
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
......@@ -379,7 +375,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
......@@ -388,19 +384,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType,
CDataType,
ReduceAccDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
DGridDesc_M,
ReduceGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -443,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DPtrsGlobal p_ds_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -453,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_ds_grid_{p_ds_grid},
p_reduces_grid_{p_reduces_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
reduce_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -481,8 +477,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
}
}
......@@ -490,20 +486,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
DPtrsGlobal p_ds_grid_;
ReducePtrsGlobal p_reduces_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_;
ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
// Invoker
......@@ -528,7 +525,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
......@@ -554,16 +551,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
......@@ -576,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
else
......@@ -594,16 +591,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
......@@ -616,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
......@@ -660,37 +657,75 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_dxs,
MRaw,
NRaw,
KRaw,
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op};
reduce_in_element_ops,
reduce_out_element_ops};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -699,37 +734,73 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
ck::index_t = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
dxs_tuple,
MRaw,
NRaw,
KRaw,
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op);
reduce_in_element_ops,
reduce_out_element_ops);
}
// polymorphic
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmSplitK : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmSplitKPtr = std::unique_ptr<
DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
#include "ck/device_utility/device_prop.hpp"
......@@ -57,7 +57,7 @@ template <typename ADataType,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceGemmXdlSplitK
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
: public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
......@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
#include "ck/device_utility/device_prop.hpp"
......@@ -59,7 +59,7 @@ template <typename ADataType,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
struct DeviceGemmXdlSplitKCShuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
: public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -420,21 +420,22 @@ struct DeviceGemmXdlSplitKCShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct DeviceNormalization : public BaseOperator
{
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the normalization operation is applied
// alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual index_t GetRank() const = 0;
virtual index_t GetNumReduceDim() const = 0;
};
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -9,6 +9,7 @@
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
......@@ -33,8 +34,15 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceSoftmax : public BaseOperator
struct DeviceSoftmax : public DeviceNormalization
{
static constexpr index_t kRank = Rank;
static constexpr index_t kNumReduceDim = NumReduceDim;
virtual index_t GetRank() const override { return kRank; }
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
using PassThrough = tensor_operation::element_wise::PassThrough;
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
......@@ -61,18 +69,33 @@ struct DeviceSoftmax : public BaseOperator
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
false>;
using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
true>;
struct Argument : public Reduction::Argument
{
......@@ -121,8 +144,19 @@ struct DeviceSoftmax : public BaseOperator
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto kernel_main =
kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>;
bool sweep_once =
in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>
: kernel_softmax<GridwiseSoftmaxGeneric,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>;
float avg_time = 0;
......@@ -167,24 +201,34 @@ struct DeviceSoftmax : public BaseOperator
return true;
};
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the softmax normalization operate on
// alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
AccDataType alpha,
AccDataType beta,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev)
void* out_dev) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
alpha,
beta,
*static_cast<const AccDataType*>(alpha),
*static_cast<const AccDataType*>(beta),
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
......
......@@ -11,8 +11,8 @@ namespace element_wise {
struct Add
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
......@@ -28,6 +28,13 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) + x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
......
......@@ -23,19 +23,19 @@ template <typename GridwiseGemm,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock,
typename ReduceGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
......@@ -46,15 +46,15 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid,
const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const C1ElementwiseOperation c1_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const ReduceInElementwiseOperations reduce_in_element_ops,
const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -63,7 +63,7 @@ __global__ void
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
......@@ -72,42 +72,42 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_grid,
p_c1_grid,
p_ds_grid,
p_bias_grid,
p_d0_grid,
p_reduces_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
reduce_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_c0_grid;
ignore = p_c1_grid;
ignore = p_ds_grid;
ignore = p_bias_grid;
ignore = p_d0_grid;
ignore = p_reduces_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c1_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = reduce_in_element_ops;
ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock;
ignore = reduce_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -119,22 +119,22 @@ template <typename FloatAB,
typename FloatC0,
typename FloatC1,
typename FloatReduceAcc,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename ReduceGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename DGridDesc_M,
typename ReduceGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__ __device__ static constexpr auto
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
{
const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock;
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return d_grid_desc_mblock_mperblock;
return reduce_grid_desc_mblock_mperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
......@@ -352,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const ReduceInElementwiseOperations& reduce_in_element_ops,
const ReduceAccElementwiseOperations& reduce_out_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -390,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
......@@ -725,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock =
// VGPR reduce_thread_desc_mperblock
constexpr auto reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock =
// VGPR reduce_thread_desc_mblock_mperblock
constexpr auto reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
......@@ -759,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_d_grid = p_ds_grid[I];
auto d_out_element_op = dxs_out_element_op[I];
auto p_reduce_grid = p_reduces_grid[I];
auto reduce_acc_element_op = reduce_out_element_ops[I];
return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
decltype(d_out_element_op),
remove_pointer_t<decltype(p_reduce_grid)>,
decltype(reduce_thread_desc_mblock_mperblock),
decltype(reduce_grid_desc_mblock_mperblock),
decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I),
ReduceGlobalMemoryDataOperation::At(I),
1,
false>{d_grid_desc_mblock_mperblock,
false>{reduce_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op};
reduce_acc_element_op};
},
Number<p_ds_grid.Size()>{});
Number<p_reduces_grid.Size()>{});
// c0 and c1
constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
......@@ -909,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In];
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d_thread_buf =
auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& d_in_element_op = dxs_in_element_op[In];
auto& reduce_in_element_op = reduce_in_element_ops[In];
auto& d_reduce_thread_copy_vgpr_to_global =
dxs_reduce_thread_copy_vgpr_to_global(In);
auto& reduce_thread_copy_vgpr_to_global =
reduce_tuple_thread_copy_vgpr_to_global(In);
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
DReduceOperation,
decltype(reduce_thread_desc_mperblock),
ReduceOperation,
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
const auto reduce_identityVal =
ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
......@@ -946,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
reduce_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
});
});
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global.Run(
d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d_thread_buf,
d_grid_desc_mblock_mperblock,
d_grid_buf);
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
reduce_thread_buf,
reduce_grid_desc_mblock_mperblock,
reduce_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
reduce_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment