Commit 63ed48e4 authored by Umang Yadav's avatar Umang Yadav
Browse files

Fixes

parent 711ff872
...@@ -474,10 +474,12 @@ template <> ...@@ -474,10 +474,12 @@ template <>
class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>> class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
{ {
public: public:
// TODO :figure out epsilon in Hex to make it constexpr
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>
epsilon() epsilon()
{ {
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(float(0.0625)); return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>(
0x28, migraphx_fp8::hip_f8<>::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>
...@@ -493,13 +495,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>> ...@@ -493,13 +495,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(); return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>
min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(-1.0f) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>
lowest() lowest()
{ {
...@@ -521,7 +516,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>> ...@@ -521,7 +516,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>
epsilon() epsilon()
{ {
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(0.125)); return migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>(
0x34, migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>
...@@ -538,12 +534,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>> ...@@ -538,12 +534,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>( return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>()); migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>
min()
{
return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(-1.0f)) *
migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>
lowest() lowest()
......
...@@ -38,9 +38,6 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT ...@@ -38,9 +38,6 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT
template <bool B> template <bool B>
using bool_c = std::integral_constant<bool, B>; using bool_c = std::integral_constant<bool, B>;
template <class From, class To>
using is_convertible = std::is_convertible<From, To>;
#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y #define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) #define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
......
...@@ -36,7 +36,8 @@ namespace migraphx { ...@@ -36,7 +36,8 @@ namespace migraphx {
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{ {
idx.global_stride(out.get_shape().elements(), [&](auto i) { out[i] = f(xs[i]...); }); idx.global_stride(out.get_shape().elements(),
[&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); });
} }
template <class... Transforms> template <class... Transforms>
......
...@@ -244,8 +244,9 @@ struct reducer_base ...@@ -244,8 +244,9 @@ struct reducer_base
{ {
auto&& derived = static_cast<const Derived&>(*this); auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x); auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>( return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
[=](auto i, auto...) -> auto& { return t[i]; }); return t[i];
});
} }
} }
......
...@@ -81,7 +81,7 @@ def test_create_dyn_shape(): ...@@ -81,7 +81,7 @@ def test_create_dyn_shape():
def test_type_enum(): def test_type_enum():
mgx_types = [ mgx_types = [
'bool_type', 'double_type', 'float_type', 'half_type', 'float_type', 'int16_type', 'bool_type', 'double_type', 'float_type', 'half_type', 'float8_type', 'int16_type',
'int32_type', 'int64_type', 'int8_type', 'uint16_type', 'uint32_type', 'int32_type', 'int64_type', 'int8_type', 'uint16_type', 'uint32_type',
'uint64_type', 'uint8_type' 'uint64_type', 'uint8_type'
] ]
......
...@@ -22,10 +22,10 @@ ...@@ -22,10 +22,10 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/migraphx_float8.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <limits> #include <limits>
template <migraphx::shape::type_t Q, typename T> template <migraphx::shape::type_t Q, typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment