"sgl-router/vscode:/vscode.git/clone" did not exist on "fd5ce576a428270e2f7ef270e9cbb4ea657ff026"
Commit 0b473ccd authored by Umang Yadav
Browse files

mlir fp8

parent a6c57726
...@@ -69,7 +69,8 @@ struct ck_gemm ...@@ -69,7 +69,8 @@ struct ck_gemm
static bool is_ck_supported_type(shape::type_t t) static bool is_ck_supported_type(shape::type_t t)
{ {
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t); return contains(
{shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t);
} }
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
......
...@@ -192,6 +192,8 @@ auto is_mlir_conv(mlir_mode mode) ...@@ -192,6 +192,8 @@ auto is_mlir_conv(mlir_mode mode)
return false; return false;
if(ins->get_shape().type() == shape::int8_type) if(ins->get_shape().type() == shape::int8_type)
return true; return true;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(mode == mlir_mode::int8) if(mode == mlir_mode::int8)
return false; return false;
if(mode == mlir_mode::all) if(mode == mlir_mode::all)
...@@ -246,6 +248,7 @@ struct find_mlir_fused_ops ...@@ -246,6 +248,7 @@ struct find_mlir_fused_ops
const std::initializer_list<type_t> allowed_types = {type_t::float_type, const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type, type_t::half_type,
type_t::int8_type, type_t::int8_type,
type_t::fp8e4m3fnuz_type,
type_t::int32_type, type_t::int32_type,
type_t::bool_type}; type_t::bool_type};
// Preliminary type check. // Preliminary type check.
...@@ -284,7 +287,8 @@ struct find_mlir_fused_ops ...@@ -284,7 +287,8 @@ struct find_mlir_fused_ops
"softmax", "softmax",
"tanh", "tanh",
}; };
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type},
result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
...@@ -354,8 +358,10 @@ struct find_mlir_standalone_op ...@@ -354,8 +358,10 @@ struct find_mlir_standalone_op
auto conv_based_op = r.result; auto conv_based_op = r.result;
// enable only for fp32/fp16/i8 types // enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) { if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains( return not contains({shape::type_t::float_type,
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, shape::type_t::half_type,
shape::type_t::fp8e4m3fnuz_type,
shape::type_t::int8_type},
i->get_shape().type()); i->get_shape().type());
})) }))
return; return;
......
...@@ -299,6 +299,8 @@ struct mlir_program ...@@ -299,6 +299,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get()); result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type) else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get()); result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type) else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get()); result = mlirF64TypeGet(ctx.get());
else if(as.is_integral()) else if(as.is_integral())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment