gaoqiong / composable_kernel / Commits / a6b2f1c1

Commit a6b2f1c1, authored Feb 09, 2023 by aska-0096
Parent: a0a469e4

Add Inter-Row thread transfer

Showing 2 changed files with 122 additions and 8 deletions:

include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp (+6, -8)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (+116, -0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp (view file @ a6b2f1c1)
...
@@ -533,7 +533,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         constexpr auto b1_block_desc_l0perblock_nperblock_l1 =
             GetB1BlockDescriptor_L0PerBlock_NPerBlock_L1();

         // A1 matrix blockwise copy
-        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
             FloatAcc,
             FloatA,
             decltype(acc_thread_desc_k0_m_k1),
...
@@ -542,7 +542,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
             Sequence<1, 0, 2>,
             2,
-            n4>{tensor_operation::element_wise::PassThrough{}};
+            n4,
+            // dst Rowlane 0x76543210 0xfedcba98
+            // src Rowlane
+            0x76543210,
+            0xfedcba98>{tensor_operation::element_wise::PassThrough{}};

         // B1 matrix blockwise copy
         auto b1_blockwise_copy =
...
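The two new template arguments (0x76543210 and 0xfedcba98) are the lane-select words that the new transfer class forwards to __builtin_amdgcn_permlanex16 (see the second file). As I read the AMD builtin, each 4-bit nibble of the 64-bit pair (0xfedcba98:0x76543210) names the lane of the opposite 16-lane row that feeds the corresponding lane of this row, so this particular pair is the identity mapping across rows: lane i receives lane i of the other row. A host-side sketch of that decoding, with hypothetical helper names:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper (not from composable_kernel): nibble i of the
    // 64-bit word (hi:lo) is the index of the lane, in the opposite
    // 16-lane row, that feeds lane i of this row.
    int source_lane(uint32_t lo, uint32_t hi, int lane) // lane: 0..15
    {
        const uint64_t sel = (uint64_t(hi) << 32) | lo;
        return int((sel >> (4 * lane)) & 0xF);
    }

    int main()
    {
        // With 0x76543210 / 0xfedcba98 every lane maps to itself in the
        // other row, i.e. a pure row-to-row duplication.
        for(int lane = 0; lane < 16; ++lane)
            std::printf("lane %2d <- other-row lane %2d\n",
                        lane, source_lane(0x76543210u, 0xfedcba98u, lane));
    }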
@@ -700,12 +704,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
                               mathext::exp(max - running_max_new) * sum;

-            // Intra-Row data permutation, make swizzled A input for WMMA
-            __builtin_amdgcn_permlane16(0xeca86420, 0xfdb97531);
-            // Low/high row move data to low/high half of thread buffer
-            /* thread copy*/
-            // Inter-Row data permutation, fulfill data duplication requirement
-            __builtin_amdgcn_permlanex16(0x76543210, 0xfedcba98);

             // gemm1
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
...
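For context, the surviving running_sum_new line is the standard online-softmax merge: when a new tile raises the running maximum from running_max to running_max_new, the old partial sum is rescaled by exp(running_max - running_max_new) before the new tile's (max, sum) pair is folded in, so every exp() argument stays non-positive. The deleted permlane sketches appear to be superseded by the _InterRow transfer added in the second file. A standalone sketch of the merge step, with names of my own choosing rather than the kernel's:

    #include <cmath>
    #include <cstdio>

    // Minimal sketch of the online-softmax accumulator merge used above.
    // (running_max, running_sum) summarize the tiles seen so far;
    // (max, sum) summarize the new tile.
    void merge(float& running_max, float& running_sum, float max, float sum)
    {
        const float running_max_new = std::fmax(running_max, max);
        running_sum = std::exp(running_max - running_max_new) * running_sum +
                      std::exp(max - running_max_new) * sum;
        running_max = running_max_new;
    }

    int main()
    {
        float m = -INFINITY, s = 0.0f;
        merge(m, s, 1.0f, 2.5f); // first tile: local max 1.0, local sum 2.5
        merge(m, s, 3.0f, 1.2f); // second tile raises the max; old sum is rescaled
        std::printf("running_max=%f running_sum=%f\n", m, s);
    }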
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (view file @ a6b2f1c1)
@@ -1298,4 +1298,120 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
     ElementwiseOperation element_op_;
 };

+// Specialized for WMMA
+// A single Wave32 is composed of two rows
+// Data exchange is allowed between these two rows
+// This RowLane Dst buf will be filled from two Src bufs:
+// SrcA: the thread buffer held by this RowLane on this row
+// SrcB: the thread buffer held by this RowLane on the other row
+template <typename SrcData,
+          typename DstData,
+          typename SrcDesc,
+          typename DstDesc,
+          typename ElementwiseOperation,
+          typename SliceLengths,
+          typename DimAccessOrder,
+          index_t DstVectorDim,
+          index_t DstScalarPerVector,
+          index_t LowEightRowlaneIdx,
+          index_t HighEightRowLaneIdx,
+          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                             bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
+{
+    static constexpr index_t nDim = SliceLengths::Size();
+
+    using Index = MultiIndex<nDim>;
+
+    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(
+        const ElementwiseOperation& element_op)
+        : element_op_{element_op}
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc need to known at compile-time");
+
+        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
+                      "wrong! Not divisible");
+    }
+
+    template <typename SrcSliceOriginIdx,
+              typename DstSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc&,
+                        const DstSliceOriginIdx&,
+                        DstBuffer& dst_buf)
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc need to known at compile-time");
+
+        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
+                      "wrong! SliceOrigin need to known at compile-time");
+
+        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
+                      "wrong! Buffer need to be StaticBuffer");
+
+        // SrcDesc and src_slice_origin_idx are known at compile-time
+        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
+        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
+        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
+        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
+
+        // scalar per access on each dim
+        constexpr auto dst_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+
+        constexpr auto dst_scalar_step_in_vector =
+            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
+
+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    DimAccessOrder,
+                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;
+
+        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
+                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
+
+        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
+
+        static_for<0, num_access, 1>{}([&](auto idx_1d) {
+            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+            // copy data from src_buf into dst_vector
+            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                constexpr index_t src_offset = src_desc.CalculateOffset(
+                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                SrcData v;
+
+                // apply element-wise operation
+                element_op_(v, src_buf[Number<src_offset>{}]);
+
+                if(get_thread_local_1d_id() % 32 > 16)
+                {
+                    // apply type convert
+                    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
+                    dst_buf(Number<dst_offset + dst_buf.size() / 2>{}) =
+                        __builtin_amdgcn_permlanex16(
+                            type_convert<DstData>(
+                                dst_buf(Number<dst_offset + dst_buf.size() / 2>{})),
+                            type_convert<DstData>(v),
+                            LowEightRowlaneIdx,
+                            HighEightRowLaneIdx,
+                            1,
+                            0);
+                }
+                else
+                {
+                    // apply type convert
+                    dst_buf(Number<dst_offset + dst_buf.size() / 2>{}) = type_convert<DstData>(v);
+                    dst_buf(Number<dst_offset>{}) = __builtin_amdgcn_permlanex16(
+                        type_convert<DstData>(dst_buf(Number<dst_offset>{})),
+                        type_convert<DstData>(v),
+                        LowEightRowlaneIdx,
+                        HighEightRowLaneIdx,
+                        1,
+                        0);
+                }
+            });
+        });
+    }
+
+    ElementwiseOperation element_op_;
+};
 } // namespace ck
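For orientation: __builtin_amdgcn_permlanex16(old, src, sel_lo, sel_hi, fi, bound_ctrl) exchanges data between the two 16-lane rows of a wave32. In Run() above, each lane stores its own converted value v into one half of dst_buf and receives the opposite row's v (routed by the selector nibbles) into the other half, producing the duplicated A operand the WMMA layout expects. Below is a scalar sketch of the resulting per-lane buffer contents, under my reading of the builtin and with hypothetical names; note the kernel's high-row test is "% 32 > 16", which leaves lane 16 on the low-row path, while the sketch uses ">= 16" for a clean row split:

    #include <array>

    // Hypothetical scalar model of the inter-row duplication performed in
    // Run() for one wave32 (lanes 0..15 = row 0, lanes 16..31 = row 1),
    // using the identity selectors 0x76543210 / 0xfedcba98 from the kernel.
    struct LaneBuf
    {
        float low;  // dst_buf at dst_offset
        float high; // dst_buf at dst_offset + dst_buf.size() / 2
    };

    std::array<LaneBuf, 32> inter_row_duplicate(const std::array<float, 32>& v)
    {
        std::array<LaneBuf, 32> buf{};
        for(int lane = 0; lane < 32; ++lane)
        {
            int partner = (lane + 16) % 32; // same lane index in the other row
            if(lane >= 16)                  // high row (kernel tests % 32 > 16)
            {
                buf[lane].low  = v[lane];    // keep own value
                buf[lane].high = v[partner]; // received via permlanex16
            }
            else                            // low row
            {
                buf[lane].high = v[lane];
                buf[lane].low  = v[partner];
            }
        }
        return buf;
    }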