"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "296b01e1a15a4feecac267050543d62e921d5875"
Commit 4933b582 authored by Jing Zhang's avatar Jing Zhang
Browse files

optimized global_barrier

parent e7f10bf4
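What the hunks below add up to: the split-K grouped-GEMM path used to keep two counters per output tile (`barrier_count_start` and `barrier_count_finished`, hence the `barrier_size_grp * 2` workspace math). Blocks raced on the start counter, the first arrival wrote E with `InMemoryDataOperationEnum::Set`, and everyone else spun on the finished counter before accumulating. After this commit only the finished counter survives: the block owning K-batch slice 0 zero-fills its E tile and bumps the counter, every slice then accumulates atomically through `EGlobalMemoryDataOperation` once the zero-fill is visible, and the last finisher resets the counter for the next launch. The HIP kernel below is a minimal sketch of that protocol only, not code from this repository: the names `splitk_accumulate`, `partial`, and `flag` are invented for illustration, all `k_batch` blocks are assumed co-resident, and the relaxed spin mirrors the kernel's `__atomic_load_n(..., __ATOMIC_RELAXED)` rather than adding stronger ordering.

#include <hip/hip_runtime.h>

// One block per K-batch slice of a single output tile. "flag" is the
// zero-initialized per-tile counter (the surviving barrier_count_finished).
__global__ void splitk_accumulate(float* e_tile,        // output tile, n elements
                                  const float* partial, // k_batch x n partial sums
                                  unsigned int* flag,
                                  int n,
                                  int k_batch)
{
    const int kb = blockIdx.x;

    // Slice 0 zero-fills the tile, then signals through the counter (the
    // commit does this with a ThreadwiseTensorSliceTransfer of a zeroed
    // VGPR buffer).
    if(kb == 0)
    {
        for(int i = threadIdx.x; i < n; i += blockDim.x)
            e_tile[i] = 0.f;
        __syncthreads();
        if(threadIdx.x == 0)
            atomicAdd(flag, 1u);
    }

    // Every slice waits for the zero-fill before accumulating; the relaxed
    // spin mirrors the kernel's use of __atomic_load_n.
    if(threadIdx.x == 0)
        while(__atomic_load_n(flag, __ATOMIC_RELAXED) == 0) {}
    __syncthreads();

    // AtomicAdd accumulation, standing in for EGlobalMemoryDataOperation.
    for(int i = threadIdx.x; i < n; i += blockDim.x)
        atomicAdd(&e_tile[i], partial[kb * n + i]);
    __syncthreads();

    // 1 init bump + k_batch finish bumps: the fetch-add that returns k_batch
    // belongs to the last finisher, which resets the counter so the
    // workspace can be reused without a hipMemset.
    if(threadIdx.x == 0 && atomicAdd(flag, 1u) == (unsigned int)k_batch)
        *flag = 0;
}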
@@ -41,7 +41,7 @@ using DsDataType = ck::Tuple<D0DataType>;
 using EDataType  = F32;

 using ALayout  = Row;
-using BLayout  = Col;
+using BLayout  = Row;
 using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 using ELayout  = Row;
@@ -51,7 +51,7 @@ using BElementOp = PassThrough;
 using CDEElementOp = AddBias;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding;

 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK
 // clang-format off
@@ -60,9 +60,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 //< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
+< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
+//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 192, 32, 8, 8, 16, 16, 1, 6, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
+//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 96, 32, 8, 8, 16, 16, 1, 3, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
 //< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
 //< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
-< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>;
+//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>;
 //< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 32, 128, 32, 4, 4, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
 // clang-format on
@@ -93,7 +96,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     // GEMM shape
     std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
-    std::vector<void*> p_Cs;

     gemm_descs.reserve(group_count);
@@ -205,8 +207,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
         c_tensors_device[i]->SetZero();

-        p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
-
         gemm_descs.push_back({sum_of_m,
                               problem_size.Ns[i],
                               problem_size.Ks[i],
@@ -239,28 +239,28 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     std::vector<const void*> p_As                    = {};
     std::vector<const void*> p_Bs                    = {};
     std::vector<std::array<const void*, 1>> p_Ds     = {};
+    std::vector<void*> p_Cs                          = {};

     // do GEMM
     auto argument = gemm.MakeArgument(
         p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);

-    DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
-    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
     gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());

+    DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
     hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
                               grouped_gemm_kernel_args_.data(),
                               gemm.GetDeviceKernelArgSize(&argument),
                               hipMemcpyHostToDevice));

-    if(!gemm.IsSupportedArgument(argument))
-    {
-        throw std::runtime_error(
-            "wrong! device_gemm with the specified compilation parameters does "
-            "not support this GEMM problem");
-    }
-
     gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
     gemm.SetKBatch(argument, config.k_batch);
@@ -328,8 +328,7 @@ int main(int argc, char* argv[])
     problem_size.group_count = 16;

-    problem_size.Ms = {
-        167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
+    problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};

     for(int i = 0; i < problem_size.group_count; i++)
     {
@@ -337,7 +336,7 @@ int main(int argc, char* argv[])
         problem_size.Ks.push_back(4608);

         problem_size.stride_As.push_back(problem_size.Ks[i]);
-        problem_size.stride_Bs.push_back(problem_size.Ks[i]);
+        problem_size.stride_Bs.push_back(problem_size.Ns[i]);
         problem_size.stride_Cs.push_back(problem_size.Ns[i]);
     }
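Two easy-to-miss details in the example changes: `problem_size.Ms` now holds mostly zeros, so the test exercises empty groups (paired with the kernel's zero-size early return in the next file), and `stride_Bs` switches from `Ks[i]` to `Ns[i]` because `BLayout` flipped from `Col` to `Row` above. For a K x N matrix B the leading stride is N in row-major and K in column-major storage; the helpers below are an illustrative sketch, not code from the example:

// Offset of B(k, n) for a K x N matrix under the two layouts (sketch).
inline int b_offset_row_major(int k, int n, int stride_b) // stride_b = N
{
    return k * stride_b + n;
}
inline int b_offset_col_major(int k, int n, int stride_b) // stride_b = K
{
    return n * stride_b + k;
}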
...
@@ -67,7 +67,7 @@ __global__ void
     const index_t N = gemm_desc_ptr[group_id].N;
     const index_t K = gemm_desc_ptr[group_id].K;

-    if(M == 0 || N == 0 || K == 0)
+    if(M * N * K == 0)
         return;

     const auto StrideA = gemm_desc_ptr[group_id].StrideA;
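The folded guard is equivalent to the per-dimension checks as long as the sizes are non-negative and their product cannot overflow `index_t`; it trades three compares and branches for two multiplies and one compare. A sketch of the two forms:

// Equivalent skip conditions (sketch); the product form assumes
// M * N * K cannot overflow the index type.
inline bool skip_group_old(int M, int N, int K) { return M == 0 || N == 0 || K == 0; }
inline bool skip_group_new(int M, int N, int K) { return M * N * K == 0; }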
@@ -112,10 +112,8 @@ __global__ void
         const auto block_2_etile_map =
             GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

-        auto barrier_count_start =
-            barrier_count + group_id * barrier_size_grp * 2 + id_local % mn_blocks;
-        auto barrier_count_finished = barrier_count + group_id * barrier_size_grp * 2 +
-                                      barrier_size_grp + id_local % mn_blocks;
+        auto barrier_count_finished =
+            barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;

         GridwiseGemm::template Run<HasMainKBlockLoop,
                                    EGlobalMemoryDataOperation,
@@ -128,7 +126,6 @@ __global__ void
             p_ds_grid_,
             gemm_desc_ptr[group_id].p_e_grid,
             p_shared,
-            barrier_count_start,
             barrier_count_finished,
             a_element_op,
             b_element_op,
@@ -448,13 +445,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};

             grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);

-            grid_size_ = grid_size_grp_ * group_count_;
+            grid_size_ = grid_size_grp_ * group_count_;
         }

-        Argument(std::vector<const void*>& p_As,
-                 std::vector<const void*>& p_Bs,
-                 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
-                 std::vector<void*>& p_Es,
+        Argument(std::vector<const void*>&,
+                 std::vector<const void*>&,
+                 std::vector<std::array<const void*, NumDTensor>>&,
+                 std::vector<void*>&,
                  std::vector<GemmDesc>& gemm_descs,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
@@ -469,29 +467,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

-            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) ||
-                 0 == ck::type_convert<ck::index_t>(p_As.size())))
-            {
-                throw std::runtime_error("wrong! group_count_ != p_As || 0 != p_As.size");
-            }
-
-            if(!(group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) ||
-                 0 == ck::type_convert<ck::index_t>(p_Bs.size())))
-            {
-                throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
-            }
-
-            if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) ||
-                 0 == ck::type_convert<ck::index_t>(p_Ds.size())))
-            {
-                throw std::runtime_error("wrong! group_count_ != p_Ds || 0 != p_Ds.size");
-            }
-
-            if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
-            {
-                throw std::runtime_error("wrong! group_count_ != p_Es");
-            }
-
             gemm_desc_kernel_arg_.reserve(group_count_);

             index_t group_id = 0;
@@ -518,12 +493,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                 // pointer
                 std::array<const void*, NumDTensor> p_ds_grid;

-                static_for<0, NumDTensor, 1>{}([&](auto j) {
-                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
-
-                    p_ds_grid[j] =
-                        static_cast<const DDataType*>(p_Ds.size() == 0 ? nullptr : p_Ds[i][j]);
-                });
+                static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; });

                 std::array<index_t, NumDTensor> StrideDs;
@@ -570,10 +540,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                 }

                 gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
-                    p_As.size() == 0 ? nullptr : p_As[i],
-                    p_Bs.size() == 0 ? nullptr : p_Bs[i],
+                    nullptr,
+                    nullptr,
                     p_ds_grid,
-                    p_Es[i],
+                    nullptr,
                     AverM,
                     N,
                     K,
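With the pointer fields nulled here and the host example passing empty vectors, per-group A/B/E addresses now reach the kernel only through the device-side kernel-arg buffer. The calls below, repeated from the example file's diff, are the whole update path; this is what lets a caller re-point buffers between launches without rebuilding the `Argument`:

// Refresh per-group pointers for the next launch (calls as shown in the
// example above; grouped_gemm_kernel_args_ is the host-side staging vector).
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
                          grouped_gemm_kernel_args_.data(),
                          gemm.GetDeviceKernelArgSize(&argument),
                          hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());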
@@ -838,7 +808,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     {
         auto arg = *dynamic_cast<const Argument*>(p_arg);

-        return arg.group_count_ * (arg.barrier_size_grp_ * 2) * sizeof(uint32_t);
+        return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
     }

     size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
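A direct consequence of dropping the start counter: the barrier workspace halves, one `uint32_t` per tile per group instead of two. The counts below are assumed for illustration (`barrier_size_grp_` is not shown in this diff); only the factor of two comes from the change:

#include <cstddef>
#include <cstdint>

// Barrier workspace per launch (sketch; counts assumed for illustration).
constexpr std::size_t group_count      = 16;
constexpr std::size_t barrier_size_grp = 512;
constexpr std::size_t old_bytes = group_count * (barrier_size_grp * 2) * sizeof(std::uint32_t); // 64 KiB
constexpr std::size_t new_bytes = group_count * barrier_size_grp * sizeof(std::uint32_t);       // 32 KiB
static_assert(new_bytes * 2 == old_bytes, "start counters accounted for half the workspace");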
...
@@ -475,7 +475,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         DsGridPointer p_ds_grid,
         EDataType* __restrict__ p_e_grid,
         void* __restrict__ p_shared,
-        uint32_t* barrier_count_start,
         uint32_t* barrier_count_finished,
         const index_t KBatch,
         const AElementwiseOperation& a_element_op,
@@ -495,6 +494,17 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());

+        const auto ds_grid_buf = generate_tuple(
+            [&](auto i) {
+                return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_ds_grid[i],
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
+            },
+            Number<NumDTensor_>{});
+
+        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
         // divide block work by [M, N]
         const auto block_work_idx =
             block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -611,6 +621,61 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                                             KPack,
                                                             LoopSched>();

+#if 1
+        if(block_work_idx[I0] == 0)
+        {
+            const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
+            const index_t numNThreads = NPerBlock / nThreadSize;
+            const index_t numMThreads = BlockSize / numNThreads;
+            const index_t mThreadSize = MPerBlock / numMThreads;
+
+            const index_t m_tid = get_thread_local_1d_id() / numNThreads;
+            const index_t n_tid = get_thread_local_1d_id() % numNThreads;
+
+            auto c_thread_desc_mblock_mperblock_nblock_nperblock =
+                make_naive_tensor_descriptor_packed(
+                    make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
+
+            StaticBuffer<AddressSpaceEnum::Vgpr,
+                         EDataType,
+                         c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
+                         true>
+                e_thread_zero_buf;
+
+            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
+                EDataType,
+                EDataType,
+                decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
+                decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                ck::tensor_operation::element_wise::PassThrough,
+                Sequence<1, mThreadSize, 1, nThreadSize>,
+                Sequence<0, 1, 2, 3>,
+                3,
+                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                InMemoryDataOperationEnum::Set,
+                1,
+                true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
+                      make_multi_index(block_work_idx[I1],
+                                       m_tid * mThreadSize,
+                                       block_work_idx[I2],
+                                       n_tid * nThreadSize),
+                      ck::tensor_operation::element_wise::PassThrough{}};
+
+            c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
+                              make_tuple(I0, I0, I0, I0),
+                              e_thread_zero_buf,
+                              e_grid_desc_mblock_mperblock_nblock_nperblock,
+                              e_grid_buf);
+
+            __syncthreads();
+
+            if(threadIdx.x == 0)
+            {
+                atomicAdd(barrier_count_finished, 1);
+            }
+        }
+#endif
+
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
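The new pre-pass distributes the MPerBlock x NPerBlock zero-fill across the thread block in vector-width chunks along N. Plugging in the instance the example now selects (BlockSize 128, MPerBlock 16, NPerBlock 128, CDEShuffleBlockTransferScalarPerVector_NPerBlock 4) shows the tile is covered exactly; the snippet below just replays that arithmetic:

// Thread tiling of the zero-fill pass for the example's instance (sketch).
constexpr int BlockSize       = 128;
constexpr int MPerBlock       = 16;
constexpr int NPerBlock       = 128;
constexpr int ScalarPerVector = 4; // CDEShuffleBlockTransferScalarPerVector_NPerBlock

constexpr int nThreadSize = ScalarPerVector;         // 4 N-elements per thread, vectorized
constexpr int numNThreads = NPerBlock / nThreadSize; // 32 threads along N
constexpr int numMThreads = BlockSize / numNThreads; // 4 threads along M
constexpr int mThreadSize = MPerBlock / numMThreads; // 4 M-rows per thread

static_assert(numMThreads * numNThreads == BlockSize, "every thread participates");
static_assert(mThreadSize * numMThreads == MPerBlock, "M covered exactly");
static_assert(nThreadSize * numNThreads == NPerBlock, "N covered exactly");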
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -653,37 +718,14 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -653,37 +718,14 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
ignore = barrier_count_start; // shuffle C and write out
ignore = barrier_count_finished;
ignore = KBatch;
__shared__ index_t k_id_start_shared;
if(threadIdx.x == 0)
{ {
const auto k_id_start_t = atomicAdd(barrier_count_start, 1); if(threadIdx.x == 0)
k_id_start_shared = k_id_start_t;
if(k_id_start_t > 0)
{ {
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
} }
}
__syncthreads();
// shuffle C and write out
{
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor_>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( __syncthreads();
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -847,163 +889,83 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

-            if(k_id_start_shared == 0)
-            {
-                // blockwise copy C/D/E between LDS and global
-                auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
-                    ThisThreadBlock,
-                    decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
-                    Tuple<EDataType>,
-                    decltype(c_ds_desc_refs),
-                    decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-                    CDEElementwiseOperation_,
-                    Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, // FIXME: make
-                                                                                    // Sequence
-                                                                                    // support
-                                                                                    // arbitray type
-                    Sequence<1,
-                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                             1,
-                             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-                    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-                    Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
-                    3,                    // index_t VectorDim,
-                    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-                    sequence_merge_t<Sequence<true>,
-                                     uniform_sequence_gen_t<
-                                         NumDTensor_,
-                                         false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-                    Sequence<false>>              // ThreadTransferDstResetCoordinateAfterRunFlags
-                    {c_ds_desc_refs,
-                     idx_c_ds_block_begin,
-                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                     make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
-                     cde_element_op};
-
-                static_for<0, num_access, 1>{}([&](auto access_id) {
-                    // make sure it's safe to write to LDS
-                    block_sync_lds();
-
-                    // each thread write its data from VGPR to LDS
-                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                                  c_thread_buf,
-                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  c_shuffle_block_buf);
-
-                    // make sure it's safe to read from LDS
-                    block_sync_lds();
-
-                    // each block copy its data from LDS to global
-                    cde_block_copy_lds_and_global.Run(
-                        c_ds_desc_refs,
-                        c_ds_buf_refs,
-                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                        tie(e_grid_buf));
-
-                    if constexpr(access_id < num_access - 1)
-                    {
-                        constexpr auto cde_lds_and_global_step =
-                            sfc_cde_block.GetForwardStep(access_id);
-
-                        // move on Ds
-                        static_for<0, NumDTensor_, 1>{}([&](auto i) {
-                            cde_block_copy_lds_and_global.MoveSrcSliceWindow(
-                                c_ds_desc_refs, i + I1, cde_lds_and_global_step);
-                        });
-
-                        // move on E
-                        cde_block_copy_lds_and_global.MoveDstSliceWindow(
-                            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                            I0,
-                            cde_lds_and_global_step);
-                    }
-                });
-            }
-            else
-            {
-                // blockwise copy C/D/E between LDS and global
-                auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
-                    ThisThreadBlock,
-                    decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
-                    Tuple<EDataType>,
-                    decltype(c_ds_desc_refs),
-                    decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-                    CDEElementwiseOperation_,
-                    Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
-                                                                                // Sequence support
-                                                                                // arbitray type
-                    Sequence<1,
-                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                             1,
-                             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-                    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-                    Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
-                    3,                    // index_t VectorDim,
-                    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-                    sequence_merge_t<Sequence<true>,
-                                     uniform_sequence_gen_t<
-                                         NumDTensor_,
-                                         false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-                    Sequence<false>>              // ThreadTransferDstResetCoordinateAfterRunFlags
-                    {c_ds_desc_refs,
-                     idx_c_ds_block_begin,
-                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                     make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
-                     cde_element_op};
-
-                static_for<0, num_access, 1>{}([&](auto access_id) {
-                    // make sure it's safe to write to LDS
-                    block_sync_lds();
-
-                    // each thread write its data from VGPR to LDS
-                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                                  c_thread_buf,
-                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  c_shuffle_block_buf);
-
-                    // make sure it's safe to read from LDS
-                    block_sync_lds();
-
-                    // each block copy its data from LDS to global
-                    cde_block_copy_lds_and_global.Run(
-                        c_ds_desc_refs,
-                        c_ds_buf_refs,
-                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                        tie(e_grid_buf));
-
-                    if constexpr(access_id < num_access - 1)
-                    {
-                        constexpr auto cde_lds_and_global_step =
-                            sfc_cde_block.GetForwardStep(access_id);
-
-                        // move on Ds
-                        static_for<0, NumDTensor_, 1>{}([&](auto i) {
-                            cde_block_copy_lds_and_global.MoveSrcSliceWindow(
-                                c_ds_desc_refs, i + I1, cde_lds_and_global_step);
-                        });
-
-                        // move on E
-                        cde_block_copy_lds_and_global.MoveDstSliceWindow(
-                            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                            I0,
-                            cde_lds_and_global_step);
-                    }
-                });
-            }
-
-            __syncthreads();
+            // blockwise copy C/D/E between LDS and global
+            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
+                ThisThreadBlock,
+                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
+                Tuple<EDataType>,
+                decltype(c_ds_desc_refs),
+                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+                CDEElementwiseOperation_,
+                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
+                                                                            // Sequence support
+                                                                            // arbitray type
+                Sequence<1,
+                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                         1,
+                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+                3,                    // index_t VectorDim,
+                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                sequence_merge_t<
+                    Sequence<true>,
+                    uniform_sequence_gen_t<NumDTensor_,
+                                           false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+                Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
+                {c_ds_desc_refs,
+                 idx_c_ds_block_begin,
+                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                 make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
+                 cde_element_op};
+
+            static_for<0, num_access, 1>{}([&](auto access_id) {
+                // make sure it's safe to write to LDS
+                block_sync_lds();
+
+                // each thread write its data from VGPR to LDS
+                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                              c_thread_buf,
+                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                              c_shuffle_block_buf);
+
+                // make sure it's safe to read from LDS
+                block_sync_lds();
+
+                // each block copy its data from LDS to global
+                cde_block_copy_lds_and_global.Run(
+                    c_ds_desc_refs,
+                    c_ds_buf_refs,
+                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                    tie(e_grid_buf));
+
+                if constexpr(access_id < num_access - 1)
+                {
+                    constexpr auto cde_lds_and_global_step =
+                        sfc_cde_block.GetForwardStep(access_id);
+
+                    // move on Ds
+                    static_for<0, NumDTensor_, 1>{}([&](auto i) {
+                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                    });
+
+                    // move on E
+                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        I0,
+                        cde_lds_and_global_step);
+                }
+            });

             if(threadIdx.x == 0)
             {
                 index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);

-                if(k_id_finished_t == KBatch - 1)
+                if(k_id_finished_t == KBatch)
                 {
-                    *barrier_count_start    = 0;
                     *barrier_count_finished = 0;
                 }
             }
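The reset threshold moves from `KBatch - 1` to `KBatch` because the counter now absorbs one extra increment from the zero-fill pre-pass: one init bump plus `KBatch` finish bumps give fetch-add return values 1 through `KBatch`, so exactly one block observes `KBatch` and clears the counter for the next launch. A sequential replay of the bookkeeping (a sketch, not kernel code):

#include <cassert>

// Sequential replay of the barrier counter for one tile (sketch).
void replay_barrier(int k_batch)
{
    unsigned counter = 0;

    ++counter; // the K-batch-0 block signals "tile zero-filled"

    int last_fetch = -1;
    for(int b = 0; b < k_batch; ++b) // each slice finishes its accumulate pass
        last_fetch = (int)counter++; // fetch-add: returns the pre-increment value

    assert(last_fetch == k_batch);   // matches `k_id_finished_t == KBatch`
    if(last_fetch == k_batch)
        counter = 0;                 // workspace ready for the next launch
    assert(counter == 0);
}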
@@ -1023,7 +985,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         DsGridPointer p_ds_grid,
         void* __restrict__ p_e_grid_,
         void* __restrict__ p_shared,
-        uint32_t* barrier_count_start,
         uint32_t* barrier_count_finished,
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
@@ -1042,10 +1003,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
         const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);

-        // tensor descriptors for problem definiton
-        // const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-        // const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
-
         using DsGridDesc_M_N =
             remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
@@ -1084,7 +1041,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

-        if(kbatch_id == 0)
+        if(kbatch_id == KBatch - 1)
         {
             Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
                 p_a_grid,
@@ -1092,7 +1049,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 p_ds_grid,
                 p_e_grid,
                 p_shared,
-                barrier_count_start,
                 barrier_count_finished,
                 KBatch,
                 a_element_op,
@@ -1112,7 +1068,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 p_ds_grid,
                 p_e_grid,
                 p_shared,
-                barrier_count_start,
                 barrier_count_finished,
                 KBatch,
                 a_element_op,
...