Commit 575fc04a authored by Umang Yadav
Browse files

mlir fp8

parent ce61ea6b
...@@ -69,7 +69,8 @@ struct ck_gemm ...@@ -69,7 +69,8 @@ struct ck_gemm
static bool is_ck_supported_type(shape::type_t t) static bool is_ck_supported_type(shape::type_t t)
{ {
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t); return contains(
{shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t);
} }
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
......
...@@ -223,6 +223,8 @@ auto is_mlir_conv(mlir_mode mode) ...@@ -223,6 +223,8 @@ auto is_mlir_conv(mlir_mode mode)
return false; return false;
if(ins->get_shape().type() == shape::int8_type) if(ins->get_shape().type() == shape::int8_type)
return true; return true;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(mode == mlir_mode::int8) if(mode == mlir_mode::int8)
return false; return false;
if(mode == mlir_mode::all) if(mode == mlir_mode::all)
...@@ -288,6 +290,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -288,6 +290,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto result_type = i.get_shape().type(); const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type, const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type, type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::int8_type, type_t::int8_type,
type_t::int32_type, type_t::int32_type,
type_t::bool_type}; type_t::bool_type};
...@@ -327,7 +330,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -327,7 +330,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax", "softmax",
"tanh", "tanh",
}; };
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
...@@ -404,7 +407,7 @@ struct find_mlir_standalone_op ...@@ -404,7 +407,7 @@ struct find_mlir_standalone_op
// enable only for fp32/fp16/i8 types // enable only for fp32/fp16/fp8/i8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains( return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type()); i->get_shape().type());
})) }))
return; return;
......
...@@ -300,6 +300,8 @@ struct mlir_program ...@@ -300,6 +300,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get()); result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type) else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get()); result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type) else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get()); result = mlirF64TypeGet(ctx.get());
else if(as.is_integral()) else if(as.is_integral())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment