Commit b452c70e authored by ltqin's avatar ltqin
Browse files

add bhalf2_t data convert

parent 903c904b
......@@ -1508,7 +1508,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
......
......@@ -347,6 +347,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_vector_refs, dst_vector_refs);
});
}
else if constexpr(SrcVectorDim == DstVectorDim && SrcScalarPerVector % 2 == 0 &&
DstScalarPerVector % 2 == 0 &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<bhalf_t, remove_cvref_t<DstData>>::value)
{
auto NewSliceLengths = SliceLengths{}.template Modify(
Number<SrcVectorDim>{}, Number<SliceLengths{}[SrcVectorDim] / 2>{});
auto VectorStep = SliceLengths{} / NewSliceLengths;
static_ford<decltype(NewSliceLengths)>{}([&](auto idx) {
// convert from SrcData to DstData here
auto nidx = idx * VectorStep;
auto vhalf =
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<half2_t>(nidx);
dst_thread_scratch_.template SetAsType<bhalf2_t>(nidx,
type_convert<bhalf2_t>(vhalf));
});
}
else
{
static_ford<SliceLengths>{}([&](auto idx) {
......
......@@ -350,6 +350,23 @@ struct ThreadwiseTensorSliceTransfer_v3r3
src_vector_refs, dst_vector_refs);
});
}
else if constexpr(SrcVectorDim == DstVectorDim && SrcScalarPerVector % 2 == 0 &&
DstScalarPerVector % 2 == 0 &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<bhalf_t, remove_cvref_t<DstData>>::value)
{
auto NewSliceLengths = SliceLengths{}.template Modify(
Number<SrcVectorDim>{}, Number<SliceLengths{}[SrcVectorDim] / 2>{});
auto VectorStep = SliceLengths{} / NewSliceLengths;
static_ford<decltype(NewSliceLengths)>{}([&](auto idx) {
// convert from SrcData to DstData here
auto nidx = idx * VectorStep;
auto vhalf =
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<half2_t>(nidx);
dst_thread_scratch_.template SetAsType<bhalf2_t>(nidx,
type_convert<bhalf2_t>(vhalf));
});
}
else
{
static_ford<SliceLengths>{}([&](auto idx) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment