Commit 6fb61ded authored by Umang Yadav's avatar Umang Yadav
Browse files

working for add operation

parent 90c6a6c5
......@@ -357,6 +357,12 @@ struct alignas(1) fp8e4m3fnuz
}
};
MIGRAPHX_HIP_HOST_DEVICE inline migraphx::fp8e4m3fnuz operator+(migraphx::fp8e4m3fnuz x,
migraphx::fp8e4m3fnuz y)
{
return migraphx::fp8e4m3fnuz(float(x) + float(y));
}
inline std::ostream& operator<<(std::ostream& out, const fp8e4m3fnuz& value)
{
out << (float)(value);
......
......@@ -78,6 +78,8 @@ __device__ __host__ auto as_vec(T x, Axis axis)
{
if constexpr(N < 2)
return x;
else if constexpr(is_same<decltype(x), migraphx::fp8e4m3fnuz>{})
return x;
else
return make_tensor_view(as_vec<N>(remove_bool(x.data())),
shape_step<N>(x.get_shape(), axis));
......
......@@ -345,7 +345,7 @@ TEST_CASE(compile_math)
// clang-format on
};
std::vector<std::string> data_types;
// auto vec_sizes = {2, 4, 6};
auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
......@@ -354,9 +354,12 @@ TEST_CASE(compile_math)
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
// migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
// return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
// });
if(t != migraphx::shape::float8_type)
{
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
}
}
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::hip_compile_options options;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment