"docs/source/en/api/pipelines/stable_diffusion/upscale.mdx" did not exist on "da31075700eb5f7aae1eb974a1c185e53b74f316"
Commit 621a459f authored by Rostyslav Geyyer

Update type_convert_precision -> bf16_convert_rtn

parent 5195dbbb
@@ -97,22 +97,25 @@ struct ConvertBF16RTN
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;
 
+    // convert fp16->bf16 using rounding to nearest (rtn) via fp32
     template <>
     __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
     {
-        y = type_convert_precision<bhalf_t>(x);
+        y = bf16_convert_rtn<bhalf_t>(x);
     }
 
+    // convert fp32->bf16 using rounding to nearest (rtn)
     template <>
     __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
     {
-        y = type_convert_precision<bhalf_t>(x);
+        y = bf16_convert_rtn<bhalf_t>(x);
     }
 
+    // need to keep this specialization for fp16->fp16 ops
     template <>
     __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = type_convert_precision<half_t>(x);
+        y = type_convert<half_t>(x);
     }
 };
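Reviewer note: the "rtn" in the new name is round-to-nearest(-even), as opposed to the simple truncation a plain static_cast to bf16 storage would give. Below is a minimal host-only sketch of that rounding scheme, with uint16_t standing in for ck::bhalf_t; all names here are hypothetical and this is not CK's exact implementation.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 value held as raw bits; stands in for ck::bhalf_t in this sketch.
using bf16_bits = std::uint16_t;

// fp32 -> bf16 with round-to-nearest-even, using the standard
// bias-and-truncate scheme (a sketch of the idea, not CK's exact code).
inline bf16_bits bf16_convert_rtn_sketch(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits)); // bit-cast without UB

    if((bits & 0x7f800000u) != 0x7f800000u)
    {
        // Finite value: add the rounding bias; the low bit of the
        // surviving bf16 mantissa breaks ties toward even.
        bits += 0x7fffu + ((bits >> 16) & 1u);
    }
    else if(bits & 0xffffu)
    {
        // NaN: keep a mantissa bit set so truncation cannot turn it into Inf.
        bits |= 0x10000u;
    }
    // Inf falls through both branches and truncates to bf16 Inf.
    return static_cast<bf16_bits>(bits >> 16);
}

int main()
{
    // 1.00390625f lies exactly between bf16 1.0 and its successor 1.0078125;
    // ties-to-even keeps 1.0, i.e. bits 0x3f80.
    std::printf("0x%04x\n", bf16_convert_rtn_sketch(1.00390625f));
}
```

The fp16->fp16 specialization, by contrast, switches to plain type_convert: no mantissa bits are dropped there, so RTN buys nothing.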
@@ -339,7 +339,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         }
 
         static_ford<SliceLengths>{}([&](auto idx) {
-            // apply the src elementwise op and convert under the hood if needed
+            // apply the src elementwise op and convert to DstData under the hood if needed
             DstData dst_v;
             src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
             dst_thread_scratch_(idx) = dst_v;
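The reworded comment pins down the contract: the source elementwise op writes DstData directly, so any conversion (e.g. ConvertBF16RTN above) lives inside the op and the transfer loop stays type-agnostic. The calling pattern, boiled down to plain C++ (static_ford, the scratch tuples, and CK's types are elided; all names here are illustrative):

```cpp
#include <array>
#include <cstddef>

// Illustrative element op: it owns the SrcData -> DstData conversion, in the
// spirit of ConvertBF16RTN (CK specializes the bf16 targets with RTN).
struct ConvertSketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(x);
    }
};

// Reduced analogue of the ThreadwiseTensorSliceTransfer_v3r1 loop above.
template <typename DstData, typename SrcData, std::size_t N, typename ElementOp>
std::array<DstData, N> transfer(const std::array<SrcData, N>& src, ElementOp op)
{
    std::array<DstData, N> dst{};
    for(std::size_t i = 0; i < N; ++i)
    {
        // apply the src elementwise op and convert to DstData under the hood
        DstData dst_v;
        op(dst_v, src[i]);
        dst[i] = dst_v;
    }
    return dst;
}

int main()
{
    std::array<float, 3> src{0.5f, 1.5f, 2.5f};
    auto dst = transfer<double>(src, ConvertSketch{}); // fp32 -> fp64 elementwise
    (void)dst;
}
```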
@@ -1033,16 +1033,16 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
 // Convert X to Y with highest possible precision
 template <typename Y, typename X>
-__host__ __device__ constexpr Y type_convert_precision(X x)
-{
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
-
-    return static_cast<Y>(x);
-}
+__host__ __device__ constexpr Y bf16_convert_rtn(X x);
+// {
+//     static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+//
+//     return static_cast<Y>(x);
+// }
 
 // Convert fp32 to bf16 with RTN if higher precision is needed
 template <>
-inline __host__ __device__ constexpr bhalf_t type_convert_precision<bhalf_t, float>(float x)
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
 {
     union
     {
@@ -1086,11 +1086,11 @@ inline __host__ __device__ constexpr bhalf_t type_convert_precision<bhalf_t, float>(float x)
 // convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
 template <>
-inline __host__ __device__ constexpr bhalf_t type_convert_precision<bhalf_t, half_t>(half_t x)
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
 {
     float x_fp32 = static_cast<float>(x);
-    return type_convert_precision<bhalf_t>(x_fp32);
+    return bf16_convert_rtn<bhalf_t>(x_fp32);
 }
 
 template <typename T>
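Why the fp16 specialization routes through fp32: the widening cast is exact, so the only rounding step is the fp32->bf16 RTN one, giving a correctly rounded fp16->bf16 result overall. A sketch continuing the bf16_convert_rtn_sketch helper from the first sketch above (_Float16 is a Clang/GCC extension used as a stand-in for ck::half_t; availability is target-dependent):

```cpp
// Builds on bf16_convert_rtn_sketch(float) defined in the earlier sketch.
inline bf16_bits bf16_convert_rtn_sketch(_Float16 x)
{
    // fp16 -> fp32 is exact (every fp16 value is representable in fp32),
    // so exactly one rounding happens, in the fp32 -> bf16 step.
    float x_fp32 = static_cast<float>(x);
    return bf16_convert_rtn_sketch(x_fp32);
}
```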