Commit 7c36034d authored by Umang Yadav's avatar Umang Yadav
Browse files

Rename float8 to fp8e4m3fnuz

parent 0672c72a
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \ m(uint64_type, uint64_t) \
m(float8_type, fp8e4m3fnuz) m(fp8e4m3fnuz_type, fp8e4m3fnuz)
// clang-format on // clang-format on
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape ...@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \ m(uint64_type, uint64_t) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz) m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
...@@ -56,7 +56,7 @@ vectorize vectorize::elements(std::size_t axis, ...@@ -56,7 +56,7 @@ vectorize vectorize::elements(std::size_t axis,
{ {
// disable vectorization for fp8 types // disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) { if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::float8_type; return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
})) }))
return {1, axis}; return {1, axis};
if(std::all_of( if(std::all_of(
...@@ -93,7 +93,7 @@ vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector< ...@@ -93,7 +93,7 @@ vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<
{ {
// disable vectorization for fp8 types // disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) { if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::float8_type; return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
})) }))
return {1, axis}; return {1, axis};
if(inputs.empty()) if(inputs.empty())
......
...@@ -98,7 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -98,7 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
ctx.set_exhaustive_tune_flag(options.exhaustive_tune); ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end()); std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type); unsupported_types.erase(shape::type_t::float_type);
unsupported_types.erase(shape::type_t::float8_type); unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
unsupported_types.erase(shape::type_t::half_type); unsupported_types.erase(shape::type_t::half_type);
unsupported_types.erase(shape::type_t::bool_type); unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type); unsupported_types.erase(shape::type_t::int8_type);
......
...@@ -354,7 +354,7 @@ TEST_CASE(compile_math) ...@@ -354,7 +354,7 @@ TEST_CASE(compile_math)
if(t == migraphx::shape::half_type) if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::"); name.insert(0, "migraphx::");
data_types.push_back(name); data_types.push_back(name);
if(t != migraphx::shape::float8_type) if(t != migraphx::shape::fp8e4m3fnuz_type)
{ {
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) { migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">"; return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
......
...@@ -81,7 +81,7 @@ def test_create_dyn_shape(): ...@@ -81,7 +81,7 @@ def test_create_dyn_shape():
def test_type_enum(): def test_type_enum():
mgx_types = [ mgx_types = [
'bool_type', 'double_type', 'float_type', 'half_type', 'float8_type', 'int16_type', 'bool_type', 'double_type', 'float_type', 'half_type', 'fp8e4m3fnuz_type', 'int16_type',
'int32_type', 'int64_type', 'int8_type', 'uint16_type', 'uint32_type', 'int32_type', 'int64_type', 'int8_type', 'uint16_type', 'uint32_type',
'uint64_type', 'uint8_type' 'uint64_type', 'uint8_type'
] ]
......
...@@ -40,7 +40,6 @@ struct test_abs : verify_program<test_abs<DType>> ...@@ -40,7 +40,6 @@ struct test_abs : verify_program<test_abs<DType>>
} }
}; };
template struct test_abs<migraphx::shape::float8_type>; template struct test_abs<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_abs<migraphx::shape::half_type>; template struct test_abs<migraphx::shape::half_type>;
template struct test_abs<migraphx::shape::float_type>; template struct test_abs<migraphx::shape::float_type>;
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template<migraphx::shape::type_t DType> template <migraphx::shape::type_t DType>
struct test_acos : verify_program<test_acos<DType>> struct test_acos : verify_program<test_acos<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -41,7 +41,6 @@ struct test_acos : verify_program<test_acos<DType>> ...@@ -41,7 +41,6 @@ struct test_acos : verify_program<test_acos<DType>>
} }
}; };
template struct test_acos<migraphx::shape::float8_type>; template struct test_acos<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_acos<migraphx::shape::half_type>; template struct test_acos<migraphx::shape::half_type>;
template struct test_acos<migraphx::shape::float_type>; template struct test_acos<migraphx::shape::float_type>;
...@@ -42,6 +42,6 @@ struct test_add : verify_program<test_add<DType>> ...@@ -42,6 +42,6 @@ struct test_add : verify_program<test_add<DType>>
} }
}; };
template struct test_add<migraphx::shape::float8_type>; template struct test_add<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_add<migraphx::shape::half_type>; template struct test_add<migraphx::shape::half_type>;
template struct test_add<migraphx::shape::float_type>; template struct test_add<migraphx::shape::float_type>;
...@@ -53,4 +53,4 @@ template struct test_literal_limits<migraphx::shape::double_type, double>; ...@@ -53,4 +53,4 @@ template struct test_literal_limits<migraphx::shape::double_type, double>;
template struct test_literal_limits<migraphx::shape::half_type, migraphx::half>; template struct test_literal_limits<migraphx::shape::half_type, migraphx::half>;
template struct test_literal_limits<migraphx::shape::int32_type, int32_t>; template struct test_literal_limits<migraphx::shape::int32_type, int32_t>;
template struct test_literal_limits<migraphx::shape::int8_type, int8_t>; template struct test_literal_limits<migraphx::shape::int8_type, int8_t>;
template struct test_literal_limits<migraphx::shape::float8_type, migraphx_fp8::fp8e4m3fnuz>; template struct test_literal_limits<migraphx::shape::fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz>;
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \ m(uint64_type, uint64_t) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz) m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#ifdef __cplusplus #ifdef __cplusplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment