Commit 07f07581 authored by letaoqin's avatar letaoqin
Browse files

passthrough add bfp16 to fp16 convert

parent 82347535
......@@ -63,6 +63,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<half_t, bhalf_t>(half_t& y, const bhalf_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
......
......@@ -126,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
static_assert(SrcScalarPerVector % YSrcScalarPerVector == 0, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatD,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment