gaoqiong / composable_kernel
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "00899f191b1ffbe24bdaf982d8ebf6a4e29697c3"
Commit fee53701, authored Oct 18, 2023 by Adam Osewski

Store partials from VGPR to GMEM

parent e114409b

Showing 2 changed files with 152 additions and 33 deletions
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp   +1 / -33
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp   +151 / -0
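In split-K GEMM, the K dimension of one output tile is divided among several work-groups, so every group except the one owning the first K-split must park its partial accumulators in a global-memory workspace before a final reduction. As a rough mental model of what this commit wires up (a minimal host-side sketch under my own naming, not CK's implementation):

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of split-K: each "block" owns one K-chunk,
// accumulates a partial C, and stores it to its own workspace slab
// ("StorePartials"). A final pass reduces the k_batch slabs into C.
void splitk_gemm_model(const std::vector<float>& A, // M x K, row-major
                       const std::vector<float>& B, // K x N, row-major
                       std::vector<float>& C,       // M x N, row-major
                       std::size_t M, std::size_t N, std::size_t K,
                       std::size_t k_batch)
{
    const std::size_t k_chunk = (K + k_batch - 1) / k_batch;
    // One M*N slab per K-split, mirroring the per-block GMEM workspace.
    std::vector<float> workspace(k_batch * M * N, 0.f);

    for(std::size_t kb = 0; kb < k_batch; ++kb) // each iteration ~ one work-group
    {
        float* slab               = workspace.data() + kb * M * N;
        const std::size_t k_begin = kb * k_chunk;
        const std::size_t k_end   = std::min(K, k_begin + k_chunk);
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f; // "VGPR" accumulator
                for(std::size_t k = k_begin; k < k_end; ++k)
                    acc += A[m * K + k] * B[k * N + n];
                slab[m * N + n] = acc; // VGPR -> GMEM
            }
    }

    // The reduction the kernel's TODO refers to: sum the k_batch slabs.
    for(std::size_t i = 0; i < M * N; ++i)
    {
        float sum = 0.f;
        for(std::size_t kb = 0; kb < k_batch; ++kb)
            sum += workspace[kb * M * N + i];
        C[i] = sum;
    }
}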
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

@@ -160,24 +160,7 @@ __global__ void
     if(!b2c_tile_map.IsFirstKSplitBlock())
     {
         // Store partial results to auxiliary workspace.
-        // make results buffer tensor descriptor (registers).
-        // make workspace gmem tensor descriptor
-        // create ThreadGroupTransform and run copy
-        // if(threadIdx.x == 0)
-        // {
-        //     // using CThreadBuffer = decltype(results_buffer);
-        //     // constexpr index_t n_scalars = CThreadBuffer::s_per_buf.value;
-        //     constexpr index_t n_scalars = 4;
-        //     static_for<0, n_scalars, 1>{}([&](auto i) {
-        //         printf("[kernel] bid: %d; c_thread_buff[%d]: %f\n",
-        //                static_cast<index_t>(blockIdx.x),
-        //                i.value,
-        //                static_cast<float>(results_buffer[i]));
-        //     });
-        // }
+        gridwise_gemm.StorePartials(p_workspace);
     }

     const index_t output_tile_idx =
 ...

@@ -197,21 +180,6 @@ __global__ void
     [[maybe_unused]] const index_t flag_v = __builtin_amdgcn_readfirstlane(
         work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-    // if(threadIdx.x == 0)
-    // {
-    //     // using CThreadBuffer = decltype(results_buffer);
-    //     // constexpr index_t n_scalars = CThreadBuffer::s_per_buf.value;
-    //     constexpr index_t n_scalars = 4;
-    //     static_for<0, n_scalars, 1>{}([&](auto i) {
-    //         printf("[kernel] bid: %d; c_thread_buff[%d]: %f\n",
-    //                static_cast<index_t>(blockIdx.x),
-    //                i.value,
-    //                static_cast<float>(results_buffer[i]));
-    //     });
-    // }
     // TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
     // Signal waiting blocks that they can start using their workspace.
 ...
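For context: GetFlagValue reads a per-output-tile completion counter that the partial-writing blocks bump, and __builtin_amdgcn_readfirstlane broadcasts the loaded value from the first active lane as a wavefront-uniform scalar. A hypothetical std::atomic model of the flag protocol (the names and structure here are illustrative, not CK's WorkScheduler API):

#include <atomic>

// Hypothetical model of the split-K tile flag: every block that finishes
// storing its partial increments the tile's counter; the block elected to
// reduce polls until all k_batch partials have landed.
struct TileFlag
{
    std::atomic<int> done_splits{0};

    void SignalPartialStored() { done_splits.fetch_add(1, std::memory_order_release); }

    void WaitForAllPartials(int k_batch)
    {
        // Corresponds to polling GetFlagValue(...) until it reaches k_batch.
        while(done_splits.load(std::memory_order_acquire) < k_batch) { /* spin */ }
    }
};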
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp

@@ -519,10 +519,32 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             Number<NumDTensor>{});
     }

+    __host__ __device__ static auto
+    MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(index_t grid_size)
+    {
+        const auto w_desc_grid_i1_mperb_nperb = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
+                    make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, NPerBlock, I1.value));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
+                    make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
+            }
+        }();
+
+        return w_desc_grid_i1_mperb_nperb;
+    }
+
     // TODO: we should refactor out all those common Make... descriptors to sth like
     // gridwise_gemm_utils.hpp
     __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
     __device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }

     __device__ __host__ constexpr auto& GetCThreadBuffer()
     {
 ...
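The new descriptor lays the workspace out as (grid_size, 1, MPerBlock, NPerBlock): each work-group owns one contiguous MPerBlock x NPerBlock slab, and only the inner two strides switch between row- and column-major to match ELayout. A small illustrative helper spelling out that addressing (hypothetical, not part of the commit):

#include <cstddef>

// Hypothetical flat-offset form of the workspace descriptor: block `bid` owns
// elements [bid * MPerBlock * NPerBlock, (bid + 1) * MPerBlock * NPerBlock).
// The N0 = 1 dimension reuses the block stride, so it never moves the offset.
std::size_t workspace_offset(std::size_t bid, std::size_t m, std::size_t n,
                             std::size_t MPerBlock, std::size_t NPerBlock,
                             bool row_major)
{
    const std::size_t slab = bid * MPerBlock * NPerBlock;
    // RowMajor ELayout:    stride(M1) = NPerBlock, stride(N1) = 1.
    // ColumnMajor ELayout: stride(M1) = 1,         stride(N1) = MPerBlock.
    return row_major ? slab + m * NPerBlock + n : slab + n * MPerBlock + m;
}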
@@ -673,6 +695,135 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             block_2_etile_map);
     }

+    __device__ void StorePartials(void* __restrict__ p_workspace)
+    {
+        // M0 = grid_size
+        // N0 = 1
+        // M1 = MPerBlock
+        // N1 = NPerBlock
+        const auto workspace_grid_desc_m0_n0_m1_n1 =
+            MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
+
+        const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
+        const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
+
+        auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
+        auto w_grid_buf       = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1.GetElementSpaceSize());
+
+        const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
+
+        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
+            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
+        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
+        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
+        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
+        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
+        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
+        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
+        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+
+        // M0 = grid_size -> MRepeats
+        // N0 = 1 -> NRepeats
+        const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
+            workspace_grid_desc_m0_n0_m1_n1,
+            make_tuple(make_pass_through_transform(w_grid_m0),
+                       make_pass_through_transform(w_grid_n0),
+                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
+                       make_unmerge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
+
+        const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+            workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+            make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
+                       make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
+                       make_pass_through_transform(M1),  // MWave
+                       make_pass_through_transform(N1),  // NWave
+                       make_pass_through_transform(M2),  // mfma_instr.num_groups_per_blk
+                       make_pass_through_transform(M3),  // mfma_instr.num_input_blks
+                       make_pass_through_transform(M4),  // mfma_instr.group_size
+                       make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
+            make_tuple(Sequence<0, 2>{},
+                       Sequence<1, 3>{},
+                       Sequence<4>{},
+                       Sequence<5>{},
+                       Sequence<6>{},
+                       Sequence<7>{},
+                       Sequence<8>{},
+                       Sequence<9>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{},
+                       Sequence<6>{},
+                       Sequence<7>{}));
+
+        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+        const auto c_thread_mtx_on_block =
+            blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+
+        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
+        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
+
+        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
+                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                make_tuple(Sequence<0>{}));
+
+        const auto m_thread_data_on_block_idx =
+            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+                make_multi_index(m_thread_data_on_block));
+
+        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto n_thread_data_on_block_idx =
+            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+                make_multi_index(n_thread_data_on_block));
+
+        auto c_thread_copy_vgpr_to_gmem = ThreadwiseTensorSliceTransfer_v1r3<
+            AccDataType,
+            AccDataType,
+            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+            decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+            ck::tensor_operation::element_wise::PassThrough,
+            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                             // DimAccessOrder
+            7,                                                            // DstVectorDim
+            1,                                                            // DstScalarPerVector
+            InMemoryDataOperationEnum::Set,
+            1,     // DstScalarStrideInVector
+            true>  // DstResetCoordinateAfterRun
+            {workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+             make_multi_index(static_cast<index_t>(blockIdx.x),
+                              n_thread_data_on_block_idx[I0],
+                              m_thread_data_on_block_idx[I1],
+                              n_thread_data_on_block_idx[I1],
+                              m_thread_data_on_block_idx[I2],
+                              m_thread_data_on_block_idx[I3],
+                              m_thread_data_on_block_idx[I4],
+                              n_thread_data_on_block_idx[I2]),
+             ck::tensor_operation::element_wise::PassThrough{}};
+
+        c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                       make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                       c_thread_buf,
+                                       workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                       w_grid_buf);
+    }
+
     // template <typename CThreadBufer,
     //           InMemoryDataOperationEnum EGlobalMemoryDataOperation,
     //           index_t NumDTensor_,
 ...
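One consequence of the (grid_size, 1, MPerBlock, NPerBlock) workspace descriptor is the host-side allocation it implies: one MPerBlock x NPerBlock tile of AccDataType per launched work-group. A hedged back-of-envelope helper (my own sketch, not CK's workspace-size API):

#include <cstddef>

// Hypothetical sizing rule implied by GetElementSpaceSize() of the new
// descriptor: grid_size contiguous MPerBlock x NPerBlock slabs of AccDataType.
std::size_t partials_workspace_bytes(std::size_t grid_size,
                                     std::size_t m_per_block,
                                     std::size_t n_per_block,
                                     std::size_t acc_elem_bytes) // e.g. sizeof(float)
{
    return grid_size * m_per_block * n_per_block * acc_elem_bytes;
}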