Commit 0b473ccd authored by Umang Yadav's avatar Umang Yadav
Browse files

mlir fp8

parent a6c57726
......@@ -69,7 +69,8 @@ struct ck_gemm
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
return contains(
{shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t);
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
......
......@@ -192,6 +192,8 @@ auto is_mlir_conv(mlir_mode mode)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
......@@ -246,6 +248,7 @@ struct find_mlir_fused_ops
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::int8_type,
type_t::fp8e4m3fnuz_type,
type_t::int32_type,
type_t::bool_type};
// Preliminary type check.
......@@ -284,7 +287,8 @@ struct find_mlir_fused_ops
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type},
result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
......@@ -354,9 +358,11 @@ struct find_mlir_standalone_op
auto conv_based_op = r.result;
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::fp8e4m3fnuz_type,
shape::type_t::int8_type},
i->get_shape().type());
}))
return;
......
......@@ -299,6 +299,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get());
else if(as.is_integral())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment