Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0b473ccd
Commit
0b473ccd
authored
Nov 26, 2023
by
Umang Yadav
Browse files
mlir fp8
parent
a6c57726
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
5 deletions
+14
-5
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+2
-1
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+10
-4
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+2
-0
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
0b473ccd
...
...
@@ -69,7 +69,8 @@ struct ck_gemm
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
return
contains
(
{
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
,
shape
::
fp8e4m3fnuz_type
},
t
);
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
0b473ccd
...
...
@@ -192,6 +192,8 @@ auto is_mlir_conv(mlir_mode mode)
return
false
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
int8_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
mode
==
mlir_mode
::
int8
)
return
false
;
if
(
mode
==
mlir_mode
::
all
)
...
...
@@ -246,6 +248,7 @@ struct find_mlir_fused_ops
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
int8_type
,
type_t
::
fp8e4m3fnuz_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
// Preliminary type check.
...
...
@@ -284,7 +287,8 @@ struct find_mlir_fused_ops
"softmax"
,
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
...
...
@@ -354,9 +358,11 @@ struct find_mlir_standalone_op
auto
conv_based_op
=
r
.
result
;
// enable only for fp32/fp16/i8 types
if
(
std
::
any_of
(
conv_based_op
->
inputs
().
begin
(),
conv_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
fp8e4m3fnuz_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
}))
return
;
...
...
src/targets/gpu/mlir.cpp
View file @
0b473ccd
...
...
@@ -299,6 +299,8 @@ struct mlir_program
result
=
mlirF32TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
half_type
)
result
=
mlirF16TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
fp8e4m3fnuz_type
)
result
=
mlirFloat8E4M3FNUZTypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
double_type
)
result
=
mlirF64TypeGet
(
ctx
.
get
());
else
if
(
as
.
is_integral
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment