"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "edc89778c367006de550beadb8127b5a428f6bad"
Commit 07f07581 authored by letaoqin's avatar letaoqin
Browse files

passthrough add bfp16 to fp16 convert

parent 82347535
...@@ -63,6 +63,12 @@ struct PassThrough ...@@ -63,6 +63,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x); y = type_convert<bhalf_t>(x);
} }
// Explicit specialization of operator() for a half_t destination and a
// bhalf_t source. Stores type_convert<half_t>(x) into y; x is taken by
// const reference and is not modified. Declared __host__ __device__ so the
// same definition is usable from both host and device code.
template <>
__host__ __device__ void operator()<half_t, bhalf_t>(half_t& y, const bhalf_t& x) const
{
y = type_convert<half_t>(x);
}
template <> template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{ {
......
...@@ -126,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -126,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, ""); static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, ""); static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
static_assert(SrcScalarPerVector % YSrcScalarPerVector == 0, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr, using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatD, FloatD,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment