Unverified Commit ffabd70a authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Add support for half_t and bfloat to reduction operations (#1395)

* Add support for half_t and bfloat to reduction operations

* Fix bhalf convert

* Next fix bf16
parent 33b2a2bd
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -52,11 +52,19 @@ struct Add ...@@ -52,11 +52,19 @@ struct Add
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Add accumulator!"); "The data type is not supported by the Add accumulator!");
a = a + b; a = a + b;
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ + b_);
}
}; };
struct SquaredAdd struct SquaredAdd
...@@ -104,11 +112,19 @@ struct Mul ...@@ -104,11 +112,19 @@ struct Mul
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Mul accumulator!"); "The data type is not supported by the Mul accumulator!");
a = a * b; a = a * b;
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ * b_);
}
}; };
struct Max struct Max
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment