Commit 7c36034d authored by Umang Yadav's avatar Umang Yadav
Browse files

Rename float8 to fp8e4m3fnuz

parent 0672c72a
......@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(float8_type, fp8e4m3fnuz)
m(fp8e4m3fnuz_type, fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......
......@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz)
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
......@@ -56,7 +56,7 @@ vectorize vectorize::elements(std::size_t axis,
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::float8_type;
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(std::all_of(
......@@ -93,7 +93,7 @@ vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::float8_type;
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(inputs.empty())
......
......@@ -98,7 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type);
unsupported_types.erase(shape::type_t::float8_type);
unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
unsupported_types.erase(shape::type_t::half_type);
unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type);
......
......@@ -354,7 +354,7 @@ TEST_CASE(compile_math)
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
if(t != migraphx::shape::float8_type)
if(t != migraphx::shape::fp8e4m3fnuz_type)
{
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
......
......@@ -81,7 +81,7 @@ def test_create_dyn_shape():
def test_type_enum():
mgx_types = [
'bool_type', 'double_type', 'float_type', 'half_type', 'float8_type', 'int16_type',
'bool_type', 'double_type', 'float_type', 'half_type', 'fp8e4m3fnuz_type', 'int16_type',
'int32_type', 'int64_type', 'int8_type', 'uint16_type', 'uint32_type',
'uint64_type', 'uint8_type'
]
......
......@@ -40,7 +40,6 @@ struct test_abs : verify_program<test_abs<DType>>
}
};
template struct test_abs<migraphx::shape::float8_type>;
template struct test_abs<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_abs<migraphx::shape::half_type>;
template struct test_abs<migraphx::shape::float_type>;
......@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
template<migraphx::shape::type_t DType>
template <migraphx::shape::type_t DType>
struct test_acos : verify_program<test_acos<DType>>
{
migraphx::program create_program() const
......@@ -41,7 +41,6 @@ struct test_acos : verify_program<test_acos<DType>>
}
};
template struct test_acos<migraphx::shape::float8_type>;
template struct test_acos<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_acos<migraphx::shape::half_type>;
template struct test_acos<migraphx::shape::float_type>;
......@@ -42,6 +42,6 @@ struct test_add : verify_program<test_add<DType>>
}
};
template struct test_add<migraphx::shape::float8_type>;
template struct test_add<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_add<migraphx::shape::half_type>;
template struct test_add<migraphx::shape::float_type>;
......@@ -53,4 +53,4 @@ template struct test_literal_limits<migraphx::shape::double_type, double>;
template struct test_literal_limits<migraphx::shape::half_type, migraphx::half>;
template struct test_literal_limits<migraphx::shape::int32_type, int32_t>;
template struct test_literal_limits<migraphx::shape::int8_type, int8_t>;
template struct test_literal_limits<migraphx::shape::float8_type, migraphx_fp8::fp8e4m3fnuz>;
template struct test_literal_limits<migraphx::shape::fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz>;
......@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment