"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "5e5b822bb0b4ec02639d3e8e750944a6024dc29f"
Commit 134fc2e7 authored by Adam Osewski
Browse files

Fix StorePartials.

Pass a pointer to the whole workspace, not the shifted one.
parent 88a4fbfb
......@@ -319,6 +319,7 @@ int main(int argc, char* argv[])
if(argc < 11)
{
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
......
......@@ -162,11 +162,7 @@ __global__ void
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
void* __restrict__ p_block_workspace = reinterpret_cast<void* __restrict__>(
reinterpret_cast<char*>(p_workspace) + blockIdx.x * GridwiseGemm::GetMPerBlock() *
GridwiseGemm::GetNPerBlock() *
sizeof(typename GridwiseGemm::AccType));
gridwise_gemm.StorePartials(p_block_workspace);
gridwise_gemm.StorePartials(p_workspace);
}
work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
......
......@@ -814,26 +814,71 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
{
const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const auto workspace_grid_desc_m0_n0_m1_n1 =
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2;
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// M0 = grid_size -> MRepeats (MXdlPerWave)
// N0 = 1 -> NRepeats (NXdlPerWave)
const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1,
make_tuple(make_pass_through_transform(w_grid_m0),
make_pass_through_transform(w_grid_n0),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
make_pass_through_transform(M1), // MWave
make_pass_through_transform(N1), // NWave
make_pass_through_transform(M2), // mfma_instr.num_groups_per_blk
make_pass_through_transform(M3), // mfma_instr.num_input_blks
make_pass_through_transform(M4), // mfma_instr.group_size
make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
make_tuple(Sequence<0, 2>{},
Sequence<1, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{},
Sequence<8>{},
Sequence<9>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
......@@ -869,14 +914,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // DimAccessOrder
7, // DstVectorDim,
1, // DstScalarPerVector
// N -> then M dims
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // DimAccessOrder
7, // DstVectorDim,
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{// DstResetCoordinateAfterRun
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_block_idx[I0],
make_multi_index((static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
......@@ -916,7 +962,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
......@@ -929,8 +974,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// M0 = grid_size -> MRepeats
// N0 = 1 -> NRepeats
// M0 = grid_size -> MRepeats (MXdlPerWave)
// N0 = 1 -> NRepeats (NXdlPerWave)
const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1,
make_tuple(make_pass_through_transform(w_grid_m0),
......@@ -1003,7 +1048,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2), // SrcDesc,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), // DstDesc,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // DimAccessOrder,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // DimAccessOrder,
7, // SrcVectorDim,
1, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment