Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b452c70e
Commit
b452c70e
authored
Feb 14, 2023
by
ltqin
Browse files
add bhalf2_t data convert
parent
903c904b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
1 deletion
+35
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+17
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp
+17
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
b452c70e
...
...
@@ -1508,7 +1508,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
9
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
b452c70e
...
...
@@ -347,6 +347,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_vector_refs
,
dst_vector_refs
);
});
}
else
if
constexpr
(
SrcVectorDim
==
DstVectorDim
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
&&
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
bhalf_t
,
remove_cvref_t
<
DstData
>>::
value
)
{
auto
NewSliceLengths
=
SliceLengths
{}.
template
Modify
(
Number
<
SrcVectorDim
>{},
Number
<
SliceLengths
{}[
SrcVectorDim
]
/
2
>
{});
auto
VectorStep
=
SliceLengths
{}
/
NewSliceLengths
;
static_ford
<
decltype
(
NewSliceLengths
)
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
auto
nidx
=
idx
*
VectorStep
;
auto
vhalf
=
src_thread_scratch_tuple_
[
thread_scratch_id
].
template
GetAsType
<
half2_t
>(
nidx
);
dst_thread_scratch_
.
template
SetAsType
<
bhalf2_t
>(
nidx
,
type_convert
<
bhalf2_t
>
(
vhalf
));
});
}
else
{
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp
View file @
b452c70e
...
...
@@ -350,6 +350,23 @@ struct ThreadwiseTensorSliceTransfer_v3r3
src_vector_refs
,
dst_vector_refs
);
});
}
else
if
constexpr
(
SrcVectorDim
==
DstVectorDim
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
&&
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
bhalf_t
,
remove_cvref_t
<
DstData
>>::
value
)
{
auto
NewSliceLengths
=
SliceLengths
{}.
template
Modify
(
Number
<
SrcVectorDim
>{},
Number
<
SliceLengths
{}[
SrcVectorDim
]
/
2
>
{});
auto
VectorStep
=
SliceLengths
{}
/
NewSliceLengths
;
static_ford
<
decltype
(
NewSliceLengths
)
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
auto
nidx
=
idx
*
VectorStep
;
auto
vhalf
=
src_thread_scratch_tuple_
[
thread_scratch_id
].
template
GetAsType
<
half2_t
>(
nidx
);
dst_thread_scratch_
.
template
SetAsType
<
bhalf2_t
>(
nidx
,
type_convert
<
bhalf2_t
>
(
vhalf
));
});
}
else
{
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment