gaoqiong / composable_kernel_ROCM

Commit f9abcf80, authored Feb 07, 2025 by coderfeli
Parent: e947d11e

use offsets in transfer ok

Showing 3 changed files with 921 additions and 10 deletions (+921 −10):

include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp (+10 −6)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp (+9 −4)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp (+902 −0, new file)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp

@@ -7,7 +7,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp"

 namespace ck {

@@ -41,14 +41,15 @@ template <typename ThreadGroup,
           index_t DstScalarStrideInVector,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun,
+          index_t GatherDim        = 1,
           index_t NumThreadScratch = 1>
 struct ThreadGroupTensorSliceTransfer_v4r1_mod8
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

     static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+    static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});

     using Index = MultiIndex<nDim>;
+    // using GatherIndex = MultiIndex<gather_num>;

     __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8(const SrcDesc& src_desc,

@@ -56,13 +57,15 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
         const SrcElementwiseOperation& src_element_op,
         const DstDesc& dst_desc,
         const Index& dst_block_slice_origin,
-        const DstElementwiseOperation& dst_element_op)
+        const DstElementwiseOperation& dst_element_op,
+        const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
         : threadwise_transfer_(src_desc,
                                make_zero_multi_index<nDim>(),
                                src_element_op,
                                dst_desc,
                                make_zero_multi_index<nDim>(),
-                               dst_element_op)
+                               dst_element_op,
+                               gather_offsets)
     {
         static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&

@@ -173,7 +176,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
+        ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
                                            SrcElementwiseOperation,
                                            DstElementwiseOperation,
                                            DstInMemOp,

@@ -191,6 +194,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
                                            DstScalarStrideInVector,
                                            ThreadTransferSrcResetCoordinateAfterRun,
                                            ThreadTransferDstResetCoordinateAfterRun,
+                                           GatherDim,
                                            NumThreadScratch>;

     ThreadwiseTransfer threadwise_transfer_;
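Taken together, these hunks thread one new piece of state through the block-level wrapper: a statically sized array of per-row source offsets (gather_offsets), sized by the thread slice length along GatherDim and forwarded verbatim to the threadwise _gather transfer. The following host-side sketch illustrates the underlying gather-by-offset idea only; it is a minimal stand-in, not the CK implementation, and the names copy_rows_with_gather and row_len are hypothetical.

#include <array>
#include <cstdio>
#include <vector>

// Minimal host-side sketch (hypothetical, not the CK API): copy GatherNum rows,
// but redirect the source row for slice row m through a precomputed offsets
// table, which is what the gather_offsets constructor argument makes possible.
template <int GatherNum>
void copy_rows_with_gather(const std::vector<float>& src,
                           std::vector<float>& dst,
                           const std::array<int, GatherNum>& gather_offsets, // element offsets into src
                           int row_len)
{
    for(int m = 0; m < GatherNum; ++m)   // one table entry per row in the gather dimension
        for(int k = 0; k < row_len; ++k) // contiguous within a row
            dst[m * row_len + k] = src[gather_offsets[m] + k];
}

int main()
{
    constexpr int K = 4;
    std::vector<float> src(8 * K);
    for(std::size_t i = 0; i < src.size(); ++i)
        src[i] = static_cast<float>(i);

    // Offsets are pre-scaled by the row length, mirroring the commit's
    // token_offsets(m0) = p_sorted_token_ids[...] * problem.K.
    std::array<int, 2> offsets{5 * K, 2 * K}; // gather rows 5 and 2
    std::vector<float> dst(2 * K);
    copy_rows_with_gather<2>(src, dst, offsets, K);

    for(float v : dst)
        std::printf("%g ", v); // 20 21 22 23 8 9 10 11
    std::printf("\n");
    return 0;
}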
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp

@@ -1132,8 +1132,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         constexpr auto MLoadRepeats = MPerBlock / MLoadThreads;
         static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
         const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads;
-        index_t token_offset = p_sorted_token_ids[token_pos];
+        StaticallyIndexedArray<index_t, MLoadRepeats> token_offsets; //= p_sorted_token_ids[token_pos];
+        static_for<0, MLoadRepeats, 1>{}(
+            [&](auto m0) { token_offsets(m0) = p_sorted_token_ids[token_pos + MLoadThreads * m0] * problem.K; });
+        printf("threadIdx.x %d off %d\n", threadIdx.x, token_offsets(I0));

         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
         const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
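The index arithmetic above gives each group of KLoadThreads consecutive threads one row of A: thread t in block b covers sorted-token slot b * MPerBlock + t / KLoadThreads, and the token id found there is scaled by problem.K to become the element offset of that token's row. A host-side check of the same formulas follows; the sizes and the token-id permutation are made-up examples, not values from the kernel.

#include <cstdio>

int main()
{
    constexpr int MPerBlock    = 4;
    constexpr int KLoadThreads = 2;
    constexpr int K            = 8; // row length of A (assumed K = AK0 * AK1)
    const int p_sorted_token_ids[] = {7, 3, 0, 5}; // example permutation for one block

    const int block_m_id = 0;
    for(int tid = 0; tid < MPerBlock * KLoadThreads; ++tid)
    {
        // Same formulas as the hunk above: KLoadThreads threads share one row.
        const int token_pos    = block_m_id * MPerBlock + tid / KLoadThreads;
        const int token_offset = p_sorted_token_ids[token_pos] * K;
        std::printf("threadIdx.x %d off %d\n", tid, token_offset);
    }
    return 0;
}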
@@ -1183,13 +1186,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
             true,
+            1,
             BlockwiseGemmPipe::GlobalBufferNum>(
                 a_grid_desc_ak0_m_ak1,
-                make_multi_index(0, token_offset, 0),
+                make_multi_index(0, 0, 0),
                 a_element_op,
                 a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
-                ck::tensor_operation::element_wise::PassThrough{});
+                ck::tensor_operation::element_wise::PassThrough{},
+                token_offsets);

         // Thread-wise copy
         // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp
0 → 100644 (new file, +902 −0; this diff is collapsed in the web view and not reproduced here)
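Since the new 902-line threadwise transfer is not shown, the following is only a rough conceptual sketch of what a gather variant has to change relative to v3r1: along one chosen dimension, the usual coordinate-times-stride term in the source offset is replaced by a lookup into the offsets table. All strides and names below (offset_plain, offset_gather) are hypothetical illustrations, not CK code.

#include <array>
#include <cstdio>

int main()
{
    // Hypothetical (AK0, M, AK1) layout with packed strides.
    constexpr int AK0 = 2, M = 4, AK1 = 8;
    constexpr int stride_k0 = M * AK1, stride_m = AK1, stride_k1 = 1;

    // Element offsets for the gather dimension (dimension 1 = M, matching the
    // wrapper's default GatherDim = 1); example table: rows in reverse order.
    std::array<int, M> gather_offsets{};
    for(int m = 0; m < M; ++m)
        gather_offsets[m] = (M - 1 - m) * stride_m;

    auto offset_plain = [&](int k0, int m, int k1) {
        return k0 * stride_k0 + m * stride_m + k1 * stride_k1;
    };
    auto offset_gather = [&](int k0, int m, int k1) {
        // The m * stride_m term is redirected through the table.
        return k0 * stride_k0 + gather_offsets[m] + k1 * stride_k1;
    };

    std::printf("plain (1,2,3): %d\n", offset_plain(1, 2, 3));  // 51
    std::printf("gather(1,2,3): %d\n", offset_gather(1, 2, 3)); // 43
    return 0;
}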