Commit 7bf9a377 authored by Jing Zhang's avatar Jing Zhang
Browse files

added an example grouped_gemm_multi_abd

parent 59136091
...@@ -42,3 +42,8 @@ if(USE_BITINT_EXTENSION_INT4) ...@@ -42,3 +42,8 @@ if(USE_BITINT_EXTENSION_INT4)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif() endif()
endif() endif()
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using A0DataType = F16;
using B0DataType = F16;
using AsDataType = ck::Tuple<A0DataType>;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F32;
using A0Layout = Row;
using B0Layout = Col;
using AsLayout = ck::Tuple<A0Layout>;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ck::half_t>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int k_batch = 1;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<A0DataType>> a_tensors;
std::vector<Tensor<B0DataType>> b_tensors;
std::vector<Tensor<D0DataType>> d0_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d0_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d0_tensors_device,
c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d0_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
b_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
d0_tensors.push_back(Tensor<D0DataType>(
f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
<< " c_m_n: " << c_device_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(A0DataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
using GroupedGemmKernelArgument =
ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<1, 1, 1>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
d0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A0DataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B0DataType));
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
c_tensors_device[i]->SetZero();
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
{1},
{problem_size.stride_Bs[i]},
{0},
1});
grouped_gemm_kernel_args_.push_back(
{std::array<const void*, 1>{a_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, 1>{b_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
std::array<ck::index_t, 1>{problem_size.stride_As[i]},
std::array<ck::index_t, 1>{problem_size.stride_Bs[i]},
std::array<ck::index_t, 1>{0},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<std::array<const void*, 1>> p_As = {};
std::vector<std::array<const void*, 1>> p_Bs = {};
std::vector<std::array<const void*, 1>> p_Ds = {};
std::vector<void*> p_Cs = {};
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < problem_size.Ms[i]; ++m)
{
for(int n = 0; n < problem_size.Ns[i]; ++n)
{
cde_element_op(
c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n));
}
}
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
}
return pass;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(16 + 16 * i);
problem_size.Ns.push_back(128);
problem_size.Ks.push_back(64);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (>0)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct GemmMultiABDDesc
{
ck::index_t M_, N_, K_;
std::vector<ck::index_t> stride_As_;
std::vector<ck::index_t> stride_Bs_;
std::vector<ck::index_t> stride_Ds_;
ck::index_t stride_C_;
};
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmMultiABD : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor");
static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor");
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as,
std::vector<std::array<const void*, NumBTensor>>& p_bs,
std::vector<std::array<const void*, NumDTensor>>& p_ds,
std::vector<void*>& p_e,
std::vector<GemmMultiABDDesc>& gemm_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include "device_grouped_gemm_multi_abd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct GroupedGemmMultiABDKernelArgument
{
std::array<const void*, NumATensor> p_as_grid;
std::array<const void*, NumBTensor> p_bs_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
};
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -428,14 +428,14 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -428,14 +428,14 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
[&](auto i) { [&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>; using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
return MakeAGridDescriptor_M_N<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]); return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
}, },
Number<NumATensor>{}); Number<NumATensor>{});
} }
template <typename BLayout, GemmSpecialization GemmSpec> template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) MakeBGridDescriptor_N_K(index_t NRaw, index_t KRaw, index_t StrideB)
{ {
constexpr auto matrix_padder = constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{ ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
...@@ -459,15 +459,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -459,15 +459,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsLayout, GemmSpecialization GemmSpec> template <typename BsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& KRaws, MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& NRaws, const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& BsStride) const std::array<index_t, NumBTensor>& BsStride)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>; using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(KRaws[i], NRaws[i], BsStride[i]); return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(NRaws[i], KRaws[i], BsStride[i]);
}, },
Number<NumBTensor>{}); Number<NumBTensor>{});
} }
...@@ -571,6 +571,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -571,6 +571,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
return; return;
} }
// printf("%d %d\n", block_work_idx[I0], block_work_idx[I1]);
// HACK: this force m/n_block_data_idx_on_grid into SGPR // HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
...@@ -657,6 +659,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -657,6 +659,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ComputeDataType, ComputeDataType,
ComputeDataType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -953,22 +956,22 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -953,22 +956,22 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
typename DsLayout, typename DsLayout,
typename ELayout, typename ELayout,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void Run(AsGridPointer p_as_grid, __device__ static void Run2(AsGridPointer p_as_grid,
BsGridPointer p_bs_grid, BsGridPointer p_bs_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_, void* __restrict__ p_e_grid_,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op, const CDEElementwiseOperation& cde_element_op,
const index_t M, const index_t M,
const index_t N, const index_t N,
const index_t K, const index_t K,
const std::array<index_t, NumATensor> StrideAs, const std::array<index_t, NumATensor> StrideAs,
const std::array<index_t, NumBTensor> StrideBs, const std::array<index_t, NumBTensor> StrideBs,
const std::array<index_t, NumDTensor> StrideDs, const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE, const index_t StrideE,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
using AsGridDesc_M_K = using AsGridDesc_M_K =
remove_cvref_t<decltype(MakeAsGridDescriptor_M_K<AsLayout, GemmSpec>({}, {}, {}))>; remove_cvref_t<decltype(MakeAsGridDescriptor_M_K<AsLayout, GemmSpec>({}, {}, {}))>;
......
...@@ -142,9 +142,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -142,9 +142,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
{ {
// loop over space-filling curve // loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) { static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>(); auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>(); auto dst_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
// copy data from src_bufs into src_vectors // copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) { static_for<0, nSrc, 1>{}([&](auto i) {
...@@ -251,7 +251,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -251,7 +251,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
dst_vectors_tuple_(iAccess) = dst_vectors; dst_vectors_tuple_(iAccess) = dst_vectors;
// move coordinate // move coordinate
if constexpr(iAccess.value != num_access - 1) if constexpr(iAccess.value != src_num_access - 1)
{ {
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
...@@ -282,7 +282,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -282,7 +282,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{ {
// loop over space-filling curve // loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) { static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[iAccess]; auto dst_vectors = dst_vectors_tuple_[iAccess];
// copy data from buf_vectors into dst_bufs // copy data from buf_vectors into dst_bufs
...@@ -303,7 +303,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -303,7 +303,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}); });
// move coordinate // move coordinate
if constexpr(iAccess.value != num_access - 1) if constexpr(iAccess.value != dst_num_access - 1)
{ {
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
...@@ -346,25 +346,25 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -346,25 +346,25 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
{ {
if constexpr(num_access == 0) if constexpr(src_num_access == 0)
{ {
return typename SrcSpaceFillingCurve::Index{}; return typename SrcSpaceFillingCurve::Index{};
} }
else else
{ {
return SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{}); return SrcSpaceFillingCurve::GetStepBetween(Number<src_num_access - 1>{}, Number<0>{});
} }
} }
__device__ static constexpr auto GetDstCoordinateResetStep() __device__ static constexpr auto GetDstCoordinateResetStep()
{ {
if constexpr(num_access == 0) if constexpr(dst_num_access == 0)
{ {
return typename DstSpaceFillingCurve::Index{}; return typename DstSpaceFillingCurve::Index{};
} }
else else
{ {
return DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{}); return DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
} }
} }
...@@ -408,9 +408,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -408,9 +408,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>()); using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>()); using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess(); static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_; StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
SrcCoords src_coords_; SrcCoords src_coords_;
DstCoords dst_coords_; DstCoords dst_coords_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment