Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
68dd3bb4
Commit
68dd3bb4
authored
Dec 08, 2023
by
Artur Wojcik
Browse files
Merge branch 'develop' into uif2-initial
parents
8d7a8a6c
7e53592e
Changes
38
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
178 additions
and
55 deletions
+178
-55
docs/.sphinx/requirements.txt
docs/.sphinx/requirements.txt
+1
-1
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+2
-2
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+11
-5
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+4
-4
src/module.cpp
src/module.cpp
+9
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+5
-1
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+4
-4
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+20
-16
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+6
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+18
-6
src/targets/gpu/include/migraphx/gpu/device_name.hpp
src/targets/gpu/include/migraphx/gpu/device_name.hpp
+2
-0
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+2
-0
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+1
-2
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+8
-0
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+56
-1
test/verify/main.cpp
test/verify/main.cpp
+0
-1
test/verify/quant_conv.cpp
test/verify/quant_conv.cpp
+7
-3
test/verify/quant_conv_1.cpp
test/verify/quant_conv_1.cpp
+7
-3
test/verify/quant_conv_1d.cpp
test/verify/quant_conv_1d.cpp
+8
-3
test/verify/quant_conv_2.cpp
test/verify/quant_conv_2.cpp
+7
-3
No files found.
docs/.sphinx/requirements.txt
View file @
68dd3bb4
...
@@ -89,7 +89,7 @@ requests==2.28.2
...
@@ -89,7 +89,7 @@ requests==2.28.2
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core==0.30.
0
rocm-docs-core==0.30.
1
# via -r requirements.in
# via -r requirements.in
smmap==5.0.0
smmap==5.0.0
# via gitdb
# via gitdb
...
...
src/include/migraphx/op/dequantizelinear.hpp
View file @
68dd3bb4
...
@@ -72,8 +72,8 @@ struct dequantizelinear
...
@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
int64_t
>
(
input
[
i
])
-
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
double
>
(
input
[
i
])
-
static_cast
<
int64_t
>
(
zero_pts
[
i
]))
*
static_cast
<
double
>
(
zero_pts
[
i
]))
*
scales
[
i
];
scales
[
i
];
});
});
});
});
...
...
src/include/migraphx/op/quant_convolution.hpp
View file @
68dd3bb4
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
...
@@ -87,11 +88,13 @@ struct quant_convolution
...
@@ -87,11 +88,13 @@ struct quant_convolution
}
}
// all input type must be int8_type and output is float_type
// all input type must be int8_type and output is float_type
if
(
t
!=
shape
::
int8_type
)
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
shape
::
int8_type
,
shape
::
fp8e4m3fnuz_type
};
if
(
not
contains
(
supported_types
,
t
))
{
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type"
);
}
}
t
=
shape
::
int32_type
;
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
auto
padding_size
=
padding
.
size
();
auto
padding_size
=
padding
.
size
();
...
@@ -107,8 +110,11 @@ struct quant_convolution
...
@@ -107,8 +110,11 @@ struct quant_convolution
stride
[
i
]
+
stride
[
i
]
+
1
)));
1
)));
}
}
if
(
t
==
shape
::
int8_type
)
return
inputs
[
0
].
with_lens
(
t
,
output_lens
);
{
return
inputs
[
0
].
with_lens
(
shape
::
int32_type
,
output_lens
);
}
// else fp8 conv
return
inputs
[
0
].
with_lens
(
shape
::
float_type
,
output_lens
);
}
}
size_t
kdims
()
const
size_t
kdims
()
const
...
...
src/include/migraphx/op/quantizelinear.hpp
View file @
68dd3bb4
...
@@ -80,10 +80,10 @@ struct quantizelinear
...
@@ -80,10 +80,10 @@ struct quantizelinear
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
double
quantized
=
static_cast
<
double
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
static_cast
<
double
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
output
[
i
]
=
std
::
max
(
static_cast
<
double
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
std
::
min
(
static_cast
<
double
>
(
max_value
),
quantized
));
});
});
});
});
});
});
...
...
src/module.cpp
View file @
68dd3bb4
...
@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts)
...
@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts)
smod
->
finalize
(
contexts
);
smod
->
finalize
(
contexts
);
}
}
}
}
#ifndef BUILD_DEV
if
(
std
::
any_of
(
this
->
begin
(),
this
->
end
(),
[](
const
auto
i
)
{
return
i
.
get_shape
().
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
{
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs
\n
"
;
}
#endif
// Warn when an instruction is not normalized
// Warn when an instruction is not normalized
auto
ins
=
std
::
find_if
(
begin
(),
end
(),
[](
auto
&
i
)
{
return
i
.
need_normalization
();
});
auto
ins
=
std
::
find_if
(
begin
(),
end
(),
[](
auto
&
i
)
{
return
i
.
need_normalization
();
});
...
...
src/onnx/onnx_parser.cpp
View file @
68dd3bb4
...
@@ -625,7 +625,11 @@ shape::type_t get_type(int dtype)
...
@@ -625,7 +625,11 @@ shape::type_t get_type(int dtype)
case
11
:
return
shape
::
double_type
;
case
11
:
return
shape
::
double_type
;
case
12
:
return
shape
::
uint32_type
;
case
12
:
return
shape
::
uint32_type
;
case
13
:
return
shape
::
uint64_type
;
case
13
:
return
shape
::
uint64_type
;
case
18
:
return
shape
::
fp8e4m3fnuz_type
;
case
18
:
{
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs
\n
"
;
return
shape
::
fp8e4m3fnuz_type
;
}
case
14
:
case
14
:
case
15
:
case
15
:
case
16
:
case
16
:
...
...
src/rewrite_quantization.cpp
View file @
68dd3bb4
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
}
}
int64_t
max_quant
=
0
;
double
max_quant
=
0
;
int64_t
min_quant
=
0
;
double
min_quant
=
0
;
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
max_quant
=
qt
.
max
();
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
min_quant
=
qt
.
min
();
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
{
{
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
double
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
std
::
vector
<
double
>
max_data
(
s
.
elements
(),
max_quant
);
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
}
}
...
...
src/simplify_qdq.cpp
View file @
68dd3bb4
...
@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
...
@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
// occur between dequantizelinear and the quantized op
static
auto
static
auto
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
)
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
_arg
)
{
{
auto
qinp
=
dqins
->
inputs
().
front
();
auto
prev_ins
=
qop_arg
;
auto
next_ins
=
dqins
;
std
::
vector
<
instruction_ref
>
ins_inbetween
;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
while
(
next_ins
!=
qop
)
// instructions
{
while
(
prev_ins
!=
dqins
)
if
(
next_ins
->
name
()
!=
"dequantizelinear"
)
{
{
qinp
=
m
.
insert_instruction
(
qop
,
next_ins
->
get_operator
(),
qinp
);
ins_inbetween
.
push_back
(
prev_ins
);
prev_ins
=
prev_ins
->
inputs
().
front
();
}
}
next_ins
=
next_ins
->
outputs
().
front
();
auto
qinp
=
dqins
->
inputs
().
front
();
for
(
auto
ins
:
reverse_iterator_for
(
ins_inbetween
))
{
qinp
=
m
.
insert_instruction
(
dqins
,
(
*
ins
)
->
get_operator
(),
{
qinp
});
}
}
return
qinp
;
return
qinp
;
}
}
...
@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
...
@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 or FP8 type currently supported
// Only INT8 type currently supported
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
migraphx
::
shape
::
fp8e4m3fnuz_type
,
if
(
dq1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
migraphx
::
shape
::
int8_type
};
dq2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
if
(
not
contains
(
supported_types
,
dq1
->
inputs
().
front
()
->
get_shape
().
type
())
or
not
contains
(
supported_types
,
dq2
->
inputs
().
front
()
->
get_shape
().
type
()))
return
;
return
;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
...
@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
...
@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto
qop_args
=
qop
->
inputs
();
auto
qop_args
=
qop
->
inputs
();
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
);
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
_args
[
0
]
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
_args
[
1
]
);
instruction_ref
dq
;
instruction_ref
dq
;
instruction_ref
out_scale
;
instruction_ref
out_scale
;
instruction_ref
zero_point
;
instruction_ref
zero_point
;
...
...
src/targets/gpu/device_name.cpp
View file @
68dd3bb4
...
@@ -49,6 +49,12 @@ std::string get_device_name()
...
@@ -49,6 +49,12 @@ std::string get_device_name()
return
props
.
gcnArchName
;
return
props
.
gcnArchName
;
}
}
bool
gfx_has_fp8_intrinsics
()
{
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
}
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/targets/gpu/fuse_mlir.cpp
View file @
68dd3bb4
...
@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
...
@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
return
false
;
return
false
;
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
return
false
;
auto
input_arg_t
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
value
v
=
ins
->
get_operator
().
to_value
();
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
if
(
group
!=
1
)
if
(
group
!=
1
)
...
@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
...
@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
// Avoid MLIR assertion: Index < Length && "Invalid index!"
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if
(
ins
->
get_shape
().
lens
().
size
()
!=
4
)
if
(
ins
->
get_shape
().
lens
().
size
()
!=
4
)
return
false
;
return
false
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
float_type
and
input_arg_t
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
int8_type
)
if
(
ins
->
get_shape
().
type
()
==
shape
::
int8_type
)
return
true
;
return
true
;
if
(
mode
==
mlir_mode
::
int8
)
if
(
mode
==
mlir_mode
::
int8
)
...
@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
...
@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const
auto
result_type
=
i
.
get_shape
().
type
();
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
,
type_t
::
int8_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
type_t
::
bool_type
};
...
@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
...
@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax"
,
"softmax"
,
"tanh"
,
"tanh"
,
};
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
if
(
contains
(
any_type_ops
,
name
))
return
true
;
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
...
@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
...
@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
// supported.
// supported.
if
(
is_float
and
name
==
"convert"
)
if
(
is_float
and
name
==
"convert"
)
{
{
if
(
result_type
==
shape
::
fp8e4m3fnuz_type
)
{
return
false
;
}
// else
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
});
...
@@ -404,11 +415,12 @@ struct find_mlir_standalone_op
...
@@ -404,11 +415,12 @@ struct find_mlir_standalone_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
gemm_based_op
=
r
.
result
;
auto
gemm_based_op
=
r
.
result
;
//
// enable only for fp32/fp16/i8/fp8 types
// enable only for fp32/fp16/i8 types
if
(
std
::
any_of
(
gemm_based_op
->
inputs
().
begin
(),
gemm_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
if
(
std
::
any_of
(
gemm_based_op
->
inputs
().
begin
(),
gemm_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
(
return
not
contains
({
shape
::
type_t
::
float_type
,
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
,
shape
::
type_t
::
fp8e4m3fnuz_type
},
i
->
get_shape
().
type
());
i
->
get_shape
().
type
());
}))
}))
return
;
return
;
...
...
src/targets/gpu/include/migraphx/gpu/device_name.hpp
View file @
68dd3bb4
...
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
...
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT
int
get_device_id
();
MIGRAPHX_GPU_EXPORT
int
get_device_id
();
MIGRAPHX_GPU_EXPORT
bool
gfx_has_fp8_intrinsics
();
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/mlir.cpp
View file @
68dd3bb4
...
@@ -300,6 +300,8 @@ struct mlir_program
...
@@ -300,6 +300,8 @@ struct mlir_program
result
=
mlirF32TypeGet
(
ctx
.
get
());
result
=
mlirF32TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
half_type
)
else
if
(
as
.
type_enum
()
==
shape
::
half_type
)
result
=
mlirF16TypeGet
(
ctx
.
get
());
result
=
mlirF16TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
fp8e4m3fnuz_type
)
result
=
mlirFloat8E4M3FNUZTypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
double_type
)
else
if
(
as
.
type_enum
()
==
shape
::
double_type
)
result
=
mlirF64TypeGet
(
ctx
.
get
());
result
=
mlirF64TypeGet
(
ctx
.
get
());
else
if
(
as
.
is_integral
())
else
if
(
as
.
is_integral
())
...
...
src/targets/gpu/rocblas.cpp
View file @
68dd3bb4
...
@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
...
@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return
false
;
return
false
;
#else
#else
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
gfx_has_fp8_intrinsics
();
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
#endif
#endif
}
}
...
...
src/targets/gpu/target.cpp
View file @
68dd3bb4
...
@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
// whiltelist supported Ops for the FP8
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
if
(
not
gpu
::
rocblas_fp8_available
())
if
(
not
gpu
::
rocblas_fp8_available
())
{
{
unsupported_fp8_ops
.
insert
(
"dot"
);
unsupported_fp8_ops
.
insert
(
"dot"
);
}
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops
.
insert
(
"pooling"
);
if
(
not
gpu
::
gfx_has_fp8_intrinsics
())
{
unsupported_fp8_ops
.
insert
(
"convolution"
);
unsupported_fp8_ops
.
insert
(
"quant_convolution"
);
}
// add all device kernels
// add all device kernels
unsupported_fp8_ops
.
insert
(
"logsoftmax"
);
unsupported_fp8_ops
.
insert
(
"logsoftmax"
);
unsupported_fp8_ops
.
insert
(
"nonzero"
);
unsupported_fp8_ops
.
insert
(
"nonzero"
);
...
...
test/simplify_qdq_test.cpp
View file @
68dd3bb4
...
@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
...
@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
dot_add_multiple_dq_use
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
32
,
1
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
32
,
32
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
d1_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d1
);
auto
d1_tmb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
d1_t
);
auto
d1_tmbc
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
d1_tmb
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
dot_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1_tmbc
,
d2
);
auto
q3
=
add_quantize_op
(
m1
,
"quantizelinear"
,
dot_1
,
scale
,
zero
);
auto
d3
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q3
,
scale
,
zero
);
auto
dot_2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d3
,
d1
);
auto
add
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
dot_2
,
d1
});
m1
.
add_return
({
add
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q1_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q1
);
auto
q1_tmb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
q1_t
);
auto
q1_tmbc
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
q1_tmb
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1_tmbc
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_1
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_1
,
out_scale
);
auto
d3_q
=
add_quantize_op
(
m2
,
"quantizelinear"
,
d3
,
scale
,
zero
);
auto
dot_2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
d3_q
,
q1
);
auto
out_scale_2
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_2
->
get_shape
().
lens
());
auto
d4
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_2
,
out_scale_2
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d4
,
t1
);
m2
.
add_return
({
add
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv
)
TEST_CASE
(
conv
)
{
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
...
@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
...
@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto
mod1
=
create_module
();
auto
mod1
=
create_module
();
auto
mod2
=
create_module
();
auto
mod2
=
create_module
();
run_pass
(
mod2
);
run_pass
(
mod2
);
auto
match_qdq
=
migraphx
::
match
::
name
(
"dequantizelinear"
)(
auto
match_qdq
=
migraphx
::
match
::
name
(
"dequantizelinear"
)(
...
...
test/verify/main.cpp
View file @
68dd3bb4
...
@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
...
@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim"
,
"test_split_single_dyn_dim"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
});
"test_instancenorm_large_3d<migraphx::shape::half_type>"
});
rv
.
disable_test_for
(
"gpu"
,
{
"test_conv_bn_add"
});
rv
.
run
(
argc
,
argv
);
rv
.
run
(
argc
,
argv
);
}
}
test/verify/quant_conv.cpp
View file @
68dd3bb4
...
@@ -27,17 +27,21 @@
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv
:
verify_program
<
quant_conv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv
:
verify_program
<
quant_conv
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
pa
,
pc
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
pa
,
pc
);
return
p
;
return
p
;
}
}
};
};
template
struct
quant_conv
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_1.cpp
View file @
68dd3bb4
...
@@ -27,17 +27,21 @@
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_1
:
verify_program
<
quant_conv_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_1
:
verify_program
<
quant_conv_1
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
return
p
;
return
p
;
}
}
};
};
template
struct
quant_conv_1
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_1d.cpp
View file @
68dd3bb4
...
@@ -27,15 +27,16 @@
...
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv_1d
:
verify_program
<
quant_conv_1d
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_1d
:
verify_program
<
quant_conv_1d
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
migraphx
::
make_op
(
"quant_convolution"
,
...
@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
...
@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return
p
;
return
p
;
}
}
};
};
template
struct
quant_conv_1d
<
migraphx
::
shape
::
int8_type
>;
// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
test/verify/quant_conv_2.cpp
View file @
68dd3bb4
...
@@ -27,17 +27,21 @@
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_2
:
verify_program
<
quant_conv_2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_2
:
verify_program
<
quant_conv_2
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
16
,
16
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
16
,
16
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
16
,
16
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
16
,
16
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
return
p
;
return
p
;
}
}
};
};
template
struct
quant_conv_2
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment