gaoqiong / composable_kernel · Commit c1f7d9f2
Authored Oct 19, 2023 by Adam Osewski
Accumulate partial results in workspace
Parent: fee53701

Showing 2 changed files with 166 additions and 7 deletions:

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp (+3, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp (+163, -6)
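For orientation, here is a minimal host-side sketch of the scheme this commit completes. It is written for this review and is not CK code: in a split-K GEMM, each of `k_batch` workgroups computes a partial C over its K-slice into a shared workspace, and a single pass then accumulates those partials into the result. All names here (`split_k_gemm`, the layout of `workspace`) are illustrative assumptions.

```cpp
#include <cstddef>
#include <vector>

using Acc = float;

// Phase 1 writes one partial C per K-slice into a workspace; phase 2
// accumulates the partials. The GPU kernel in this commit does phase 2
// per output tile, in registers, driven by a readiness flag.
void split_k_gemm(const std::vector<Acc>& a, // M x K, row-major
                  const std::vector<Acc>& b, // K x N, row-major
                  std::vector<Acc>& c,       // M x N, row-major
                  int M, int N, int K, int k_batch)
{
    const int k_per_batch = K / k_batch; // assume K divisible, for brevity
    std::vector<Acc> workspace(static_cast<std::size_t>(k_batch) * M * N, Acc{0});

    // Phase 1: each "workgroup" kb computes a partial GEMM over its K-slice.
    for(int kb = 0; kb < k_batch; ++kb)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                Acc partial{0};
                for(int k = kb * k_per_batch; k < (kb + 1) * k_per_batch; ++k)
                    partial += a[m * K + k] * b[k * N + n];
                workspace[(static_cast<std::size_t>(kb) * M + m) * N + n] = partial;
            }

    // Phase 2: accumulate partials (AccumulatePartials in this commit).
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            Acc acc{0};
            for(int kb = 0; kb < k_batch; ++kb)
                acc += workspace[(static_cast<std::size_t>(kb) * M + m) * N + n];
            c[m * N + n] = acc;
        }
}
```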
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

```diff
@@ -177,9 +177,11 @@ __global__ void
         // Accumulate partial results. We can have different # of workgroups to reduce, thus we
         // read actual flag value.
-        [[maybe_unused]] const index_t flag_v = __builtin_amdgcn_readfirstlane(
+        const index_t flag_v = __builtin_amdgcn_readfirstlane(
             work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
+
+        gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
         // TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
         // Signal waiting blocks that they can start use their workspace.
```
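The `flag_v` read above comes from the tile-loop work scheduler: because a different number of workgroups may contribute partials to each output tile, the reducer reads the live counter rather than assuming `k_batch` contributions, and `__builtin_amdgcn_readfirstlane` broadcasts the value so it is uniform across the wavefront. Below is a minimal sketch of the assumed protocol, using `std::atomic` in place of the scheduler's global-memory flags; `TileFlag` and its members are hypothetical names, not the `work_scheduler` API.

```cpp
#include <atomic>
#include <cstdio>

// Hypothetical model of the per-tile flag the kernel reads via
// work_scheduler.GetFlagValue(): producers bump the counter after flushing
// their partials to the workspace; the reducing workgroup reads it as flag_v.
struct TileFlag
{
    std::atomic<int> finished{0};

    // Producer side: partial results for this tile are now visible in GMEM.
    void SignalPartialStored() { finished.fetch_add(1, std::memory_order_release); }

    // Reducer side: how many partials can safely be accumulated.
    int GetFlagValue() const { return finished.load(std::memory_order_acquire); }
};

int main()
{
    TileFlag flag;
    flag.SignalPartialStored(); // workgroup 0 stored its partial
    flag.SignalPartialStored(); // workgroup 1 stored its partial
    std::printf("reduce_count = %d\n", flag.GetFlagValue()); // -> 2
}
```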
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp

```diff
@@ -4,19 +4,20 @@
 #pragma once

 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
-#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/utility/reduction_functions_accumulate.hpp"

 namespace ck
 {
```
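The only new include is `ck/utility/reduction_functions_accumulate.hpp`, which supplies the `AccumulateWithNanCheck` functor used later in this file; the rest is alphabetical reordering. As a rough model of the assumed semantics (this is not the CK implementation), an add-accumulator with optional NaN propagation looks like:

```cpp
#include <cmath>
#include <limits>

// Rough model of AccumulateWithNanCheck<PropagateNan, reduce::Add, T>::Calculate(acc, v)
// for floating-point T: plain addition, optionally keeping NaN "sticky".
template <bool PropagateNan, typename T>
struct AccumulateAddModel
{
    static void Calculate(T& acc, T v)
    {
        if constexpr(PropagateNan)
        {
            // Once either operand is NaN, the running result stays NaN.
            acc = (std::isnan(acc) || std::isnan(v)) ? std::numeric_limits<T>::quiet_NaN()
                                                     : acc + v;
        }
        else
        {
            acc += v;
        }
    }
};
```

The commit instantiates the CK functor with `PropagateNan = false`, i.e. the cheap unconditional-add path.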
```diff
@@ -807,7 +808,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                1,    // DstScalarStrideInVector
                                                true> // DstResetCoordinateAfterRun
                 {workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                 make_multi_index(static_cast<index_t>(blockIdx.x),
+                 make_multi_index(static_cast<index_t>(blockIdx.x) * MXdlPerWave,
                                   n_thread_data_on_block_idx[I0],
                                   m_thread_data_on_block_idx[I1],
                                   n_thread_data_on_block_idx[I1],
```
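The one functional change in this hunk is the destination origin of the workspace store: instead of starting at M0 index `blockIdx.x`, each workgroup now writes its partials starting at `blockIdx.x * MXdlPerWave`, so the `MXdlPerWave` M-repeats of consecutive workgroups occupy disjoint rows of the workspace. This pairs with the read origin `(blockIdx.x + 1) * MXdlPerWave` in the new `AccumulatePartials` below. A toy illustration of the layout (the `MXdlPerWave` value is assumed):

```cpp
#include <cstdio>

// Illustration only: with MXdlPerWave repeats per workgroup, workgroup b owns
// M0 rows [b * MXdlPerWave, (b + 1) * MXdlPerWave) of the shared workspace.
int main()
{
    constexpr int MXdlPerWave = 4; // assumed value, for illustration
    for(int b = 0; b < 3; ++b)     // three workgroups
        std::printf("workgroup %d -> M0 origin %d\n", b, b * MXdlPerWave);
}
```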
@@ -824,6 +825,162 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
...
@@ -824,6 +825,162 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
w_grid_buf
);
w_grid_buf
);
}
}
__device__
void
AccumulatePartials
(
void
*
__restrict__
p_workspace
,
index_t
reduce_count
)
{
auto
&
c_thread_buf
=
blockwise_gemm_
.
GetCThreadBuffer
();
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
BlockwiseGemmT
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// using CThreadBufferT = ck::remove_reference_t<decltype(c_thread_buf)>;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
(),
true
>
acc_buf
{};
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const
auto
workspace_grid_desc_m0_n0_m1_n1
=
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock
(
get_grid_size
());
const
auto
w_grid_m0
=
workspace_grid_desc_m0_n0_m1_n1
.
GetLength
(
I0
);
const
auto
w_grid_n0
=
workspace_grid_desc_m0_n0_m1_n1
.
GetLength
(
I1
);
// if (threadIdx.x == 0)
// {
// printf("w_grid_desc_m0_n0_m1_n1: [%d, %d, %d, %d]\n",
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I0),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I1),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I2),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I3));
// }
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
BlockwiseGemmT
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
// M0 = grid_size -> MRepeats
// N0 = 1 -> NRepeats
const
auto
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
transform_tensor_descriptor
(
workspace_grid_desc_m0_n0_m1_n1
,
make_tuple
(
make_pass_through_transform
(
w_grid_m0
),
make_pass_through_transform
(
w_grid_n0
),
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
,
6
,
7
,
8
>
{},
Sequence
<
3
,
5
,
9
>
{}));
const
auto
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
make_merge_transform
(
make_tuple
(
w_grid_m0
,
M0
)),
// MRepeats (grid)
make_merge_transform
(
make_tuple
(
w_grid_n0
,
N0
)),
// NRepeats (grid)
make_pass_through_transform
(
M1
),
// MWave
make_pass_through_transform
(
N1
),
// NWave
make_pass_through_transform
(
M2
),
// mfma_instr.num_groups_per_blk
make_pass_through_transform
(
M3
),
// mfma_instr.num_input_blks
make_pass_through_transform
(
M4
),
// mfma_instr.group_size
make_pass_through_transform
(
N2
)),
// mfma_instr.num_threads_per_blk
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{},
Sequence
<
8
>
{},
Sequence
<
9
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}));
const
auto
c_thread_mtx_on_block
=
blockwise_gemm_
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
auto
p_workspace_grid
=
reinterpret_cast
<
AccDataType
*>
(
p_workspace
);
auto
w_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_workspace_grid
,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
auto
acc_load
=
ThreadwiseTensorSliceTransfer_v2
<
AccDataType
,
// SrcData,
AccDataType
,
// DstData,
decltype
(
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
// SrcDesc,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
// DstDesc,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLengths
()),
// SliceLengths,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
// DimAccessOrder,
7
,
// SrcVectorDim,
1
,
// SrcScalarPerVector,
1
,
// SrcScalarStrideInVector,
false
// SrcResetCoordinateAfterRun,
>
{
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
// We do not need to read this workgroup partial results since they're
// already in c_thread_buff
make_multi_index
((
static_cast
<
index_t
>
(
blockIdx
.
x
)
+
1
)
*
MXdlPerWave
,
n_thread_data_on_block_idx
[
I0
],
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
])};
using
Accumulation
=
ck
::
detail
::
AccumulateWithNanCheck
<
false
/*PropagateNan*/
,
reduce
::
Add
,
AccDataType
>
;
// We do not need to read this workgroup partial results since they're
// already in c_thread_buff
for
(
int
i_t
=
1
;
i_t
<
reduce_count
;
++
i_t
)
{
acc_buf
.
Clear
();
acc_load
.
Run
(
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
w_grid_buf
,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
acc_buf
);
static_for
<
0
,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
(),
1
>
{}(
[
&
](
auto
i_vec
)
{
Accumulation
::
Calculate
(
c_thread_buf
(
i_vec
),
acc_buf
[
i_vec
]);
});
}
}
// template <typename CThreadBufer,
// template <typename CThreadBufer,
// InMemoryDataOperationEnum EGlobalMemoryDataOperation,
// InMemoryDataOperationEnum EGlobalMemoryDataOperation,
// index_t NumDTensor_,
// index_t NumDTensor_,
...
...
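A condensed model of the loop above, written for this review and not taken from CK: the reducer's own partial already sits in registers (`c_thread_buf`), so the loop starts at `i_t = 1` and pulls in `reduce_count - 1` other slices. Because the transfer is constructed with `SrcResetCoordinateAfterRun = false` and its origin starts at `(blockIdx.x + 1) * MXdlPerWave`, the source coordinate is left advanced after each `Run()`, which appears to be how consecutive iterations walk through the other workgroups' slices without explicit window moves. All names below are illustrative, and wrap-around past the last slice is not modeled.

```cpp
#include <cstddef>
#include <vector>

// Scalar model of AccumulatePartials: workspace holds one slice of
// slice_elems partials per workgroup; this workgroup (my_slice) adds the
// other reduce_count - 1 slices onto the partial it already holds.
void accumulate_partials(std::vector<float>& c_thread_buf,    // own partial, "registers"
                         const std::vector<float>& workspace, // all workgroups' partials
                         int slice_elems,                     // elements per workgroup slice
                         int my_slice,                        // plays the role of blockIdx.x
                         int reduce_count)                    // the flag value read earlier
{
    // Source cursor starts one slice past our own results.
    std::size_t src = static_cast<std::size_t>(my_slice + 1) * slice_elems;

    for(int i_t = 1; i_t < reduce_count; ++i_t)
    {
        for(int v = 0; v < slice_elems; ++v) // acc_load.Run + Accumulation::Calculate
            c_thread_buf[v] += workspace[src + v];
        src += slice_elems;                  // coordinate not reset between Runs
    }
}
```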