Commit 575fc04a authored by Umang Yadav's avatar Umang Yadav
Browse files

mlir fp8

parent ce61ea6b
......@@ -69,7 +69,8 @@ struct ck_gemm
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
return contains(
{shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t);
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
......
......@@ -223,6 +223,8 @@ auto is_mlir_conv(mlir_mode mode)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
......@@ -288,6 +290,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
......@@ -327,7 +330,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
......@@ -404,7 +407,7 @@ struct find_mlir_standalone_op
// enable only for fp32/fp16/i8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
}))
return;
......
......@@ -300,6 +300,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get());
else if(as.is_integral())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment