Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4cc5393d
Commit
4cc5393d
authored
Dec 07, 2023
by
Paul
Browse files
Merge branch 'develop' into subwave-reduce
parents
f7d97e53
fe61d940
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
281 additions
and
65 deletions
+281
-65
docs/driver/read.rst
docs/driver/read.rst
+4
-0
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+1
-0
src/driver/main.cpp
src/driver/main.cpp
+5
-0
src/driver/passes.cpp
src/driver/passes.cpp
+109
-0
src/driver/passes.hpp
src/driver/passes.hpp
+15
-19
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+2
-2
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+11
-5
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+4
-4
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+5
-1
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+4
-4
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+20
-16
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+6
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+19
-7
src/targets/gpu/include/migraphx/gpu/device_name.hpp
src/targets/gpu/include/migraphx/gpu/device_name.hpp
+2
-0
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+2
-0
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+1
-2
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+8
-0
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+56
-1
test/verify/main.cpp
test/verify/main.cpp
+0
-1
test/verify/quant_conv.cpp
test/verify/quant_conv.cpp
+7
-3
No files found.
docs/driver/read.rst
View file @
4cc5393d
...
...
@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
Optimize when reading
.. option:: --apply-pass, -p
Passes to apply to model
.. option:: --graphviz, -g
Print out a graphviz representation.
...
...
src/driver/CMakeLists.txt
View file @
4cc5393d
...
...
@@ -25,6 +25,7 @@
add_executable
(
driver
main.cpp
verify.cpp
passes.cpp
perf.cpp
resnet50.cpp
inceptionv3.cpp
...
...
src/driver/main.cpp
View file @
4cc5393d
...
...
@@ -26,6 +26,7 @@
#include "argument_parser.hpp"
#include "command.hpp"
#include "precision.hpp"
#include "passes.hpp"
#include "perf.hpp"
#include "models.hpp"
#include "marker_roctx.hpp"
...
...
@@ -83,6 +84,7 @@ struct loader
std
::
vector
<
std
::
string
>
param_dims
;
std
::
vector
<
std
::
string
>
dyn_param_dims
;
std
::
vector
<
std
::
string
>
output_names
;
std
::
vector
<
std
::
string
>
passes
;
void
parse
(
argument_parser
&
ap
)
{
...
...
@@ -130,6 +132,7 @@ struct loader
ap
.
append
(),
ap
.
nargs
(
2
));
ap
(
optimize
,
{
"--optimize"
,
"-O"
},
ap
.
help
(
"Optimize when reading"
),
ap
.
set_value
(
true
));
ap
(
passes
,
{
"--apply-pass"
,
"-p"
},
ap
.
help
(
"Passes to apply to model"
),
ap
.
append
());
ap
(
output_type
,
{
"--graphviz"
,
"-g"
},
ap
.
help
(
"Print out a graphviz representation."
),
...
...
@@ -337,6 +340,8 @@ struct loader
migraphx
::
dead_code_elimination
{},
});
}
if
(
not
passes
.
empty
())
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
get_passes
(
passes
));
return
p
;
}
...
...
src/driver/passes.cpp
0 → 100644
View file @
4cc5393d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "passes.hpp"
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/promote_literals.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
unordered_map
<
std
::
string
,
pass
>
create_passes_lookup
()
{
std
::
unordered_map
<
std
::
string
,
pass
>
result
;
// clang-format off
std
::
initializer_list
<
pass
>
passes
=
{
auto_contiguous
{},
dead_code_elimination
{},
eliminate_allocation
{},
eliminate_common_subexpression
{},
eliminate_concat
{},
eliminate_contiguous
{},
eliminate_data_type
{},
eliminate_identity
{},
eliminate_pad
{},
inline_module
{},
insert_pad
{},
normalize_ops
{},
optimize_module
{},
promote_literals
{},
propagate_constant
{},
rewrite_gelu
{},
rewrite_pooling
{},
rewrite_quantization
{},
rewrite_rnn
{},
simplify_algebra
{},
simplify_dyn_ops
{},
simplify_qdq
{},
simplify_reshapes
{},
};
// clang-format on
for
(
const
auto
&
pass
:
passes
)
result
[
pass
.
name
()]
=
pass
;
result
[
"eliminate_dead_code"
]
=
dead_code_elimination
{};
return
result
;
}
std
::
vector
<
pass
>
get_passes
(
const
std
::
vector
<
std
::
string
>&
names
)
{
std
::
vector
<
pass
>
result
;
static
const
std
::
unordered_map
<
std
::
string
,
pass
>
lookup
=
create_passes_lookup
();
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
result
),
[](
const
std
::
string
&
name
)
{
if
(
not
contains
(
lookup
,
name
))
MIGRAPHX_THROW
(
"Unknown pass: "
+
name
);
return
lookup
.
at
(
name
);
});
return
result
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace migraphx
test/verify/test_conv_relu_half.c
pp
→
src/driver/passes.h
pp
View file @
4cc5393d
...
...
@@ -21,24 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#define MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass.hpp>
#include <vector>
struct
test_conv_relu_half
:
verify_program
<
test_conv_relu_half
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
half_type
,
{
4
,
3
,
3
,
3
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
half_type
,
{
4
,
3
,
3
,
3
}});
auto
conv
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
input
,
weights
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
conv
);
return
p
;
}
};
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
vector
<
pass
>
get_passes
(
const
std
::
vector
<
std
::
string
>&
names
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace migraphx
#endif
src/include/migraphx/op/dequantizelinear.hpp
View file @
4cc5393d
...
...
@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
int64_t
>
(
input
[
i
])
-
static_cast
<
int64_t
>
(
zero_pts
[
i
]))
*
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
double
>
(
input
[
i
])
-
static_cast
<
double
>
(
zero_pts
[
i
]))
*
scales
[
i
];
});
});
...
...
src/include/migraphx/op/quant_convolution.hpp
View file @
4cc5393d
...
...
@@ -27,6 +27,7 @@
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/value.hpp>
...
...
@@ -87,11 +88,13 @@ struct quant_convolution
}
// all input type must be int8_type and output is float_type
if
(
t
!=
shape
::
int8_type
)
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
shape
::
int8_type
,
shape
::
fp8e4m3fnuz_type
};
if
(
not
contains
(
supported_types
,
t
))
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type"
);
}
t
=
shape
::
int32_type
;
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
auto
padding_size
=
padding
.
size
();
...
...
@@ -107,8 +110,11 @@ struct quant_convolution
stride
[
i
]
+
1
)));
}
return
inputs
[
0
].
with_lens
(
t
,
output_lens
);
if
(
t
==
shape
::
int8_type
)
{
return
inputs
[
0
].
with_lens
(
shape
::
int32_type
,
output_lens
);
}
// else fp8 conv
return
inputs
[
0
].
with_lens
(
shape
::
float_type
,
output_lens
);
}
size_t
kdims
()
const
...
...
src/include/migraphx/op/quantizelinear.hpp
View file @
4cc5393d
...
...
@@ -80,10 +80,10 @@ struct quantizelinear
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
double
quantized
=
static_cast
<
double
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
double
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
double
>
(
min_value
),
std
::
min
(
static_cast
<
double
>
(
max_value
),
quantized
));
});
});
});
...
...
src/onnx/onnx_parser.cpp
View file @
4cc5393d
...
...
@@ -625,7 +625,11 @@ shape::type_t get_type(int dtype)
case
11
:
return
shape
::
double_type
;
case
12
:
return
shape
::
uint32_type
;
case
13
:
return
shape
::
uint64_type
;
case
18
:
return
shape
::
fp8e4m3fnuz_type
;
case
18
:
{
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs
\n
"
;
return
shape
::
fp8e4m3fnuz_type
;
}
case
14
:
case
15
:
case
16
:
...
...
src/rewrite_quantization.cpp
View file @
4cc5393d
...
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
}
int64_t
max_quant
=
0
;
int64_t
min_quant
=
0
;
double
max_quant
=
0
;
double
min_quant
=
0
;
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
...
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
{
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
std
::
vector
<
double
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
double
>
max_data
(
s
.
elements
(),
max_quant
);
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
}
...
...
src/simplify_qdq.cpp
View file @
4cc5393d
...
...
@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static
auto
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
)
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
_arg
)
{
auto
qinp
=
dqins
->
inputs
().
front
();
auto
next_ins
=
dqins
;
while
(
next_ins
!=
qop
)
auto
prev_ins
=
qop_arg
;
std
::
vector
<
instruction_ref
>
ins_inbetween
;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
// instructions
while
(
prev_ins
!=
dqins
)
{
if
(
next_ins
->
name
()
!=
"dequantizelinear"
)
{
qinp
=
m
.
insert_instruction
(
qop
,
next_ins
->
get_operator
(),
qinp
);
}
next_ins
=
next_ins
->
outputs
().
front
();
ins_inbetween
.
push_back
(
prev_ins
);
prev_ins
=
prev_ins
->
inputs
().
front
();
}
auto
qinp
=
dqins
->
inputs
().
front
();
for
(
auto
ins
:
reverse_iterator_for
(
ins_inbetween
))
{
qinp
=
m
.
insert_instruction
(
dqins
,
(
*
ins
)
->
get_operator
(),
{
qinp
});
}
return
qinp
;
}
...
...
@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 type currently supported
if
(
dq1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
dq2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
// Only INT8 or FP8 type currently supported
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
migraphx
::
shape
::
fp8e4m3fnuz_type
,
migraphx
::
shape
::
int8_type
};
if
(
not
contains
(
supported_types
,
dq1
->
inputs
().
front
()
->
get_shape
().
type
())
or
not
contains
(
supported_types
,
dq2
->
inputs
().
front
()
->
get_shape
().
type
()))
return
;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
...
...
@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto
qop_args
=
qop
->
inputs
();
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
);
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
_args
[
0
]
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
_args
[
1
]
);
instruction_ref
dq
;
instruction_ref
out_scale
;
instruction_ref
zero_point
;
...
...
src/targets/gpu/device_name.cpp
View file @
4cc5393d
...
...
@@ -49,6 +49,12 @@ std::string get_device_name()
return
props
.
gcnArchName
;
}
bool
gfx_has_fp8_intrinsics
()
{
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/fuse_mlir.cpp
View file @
4cc5393d
...
...
@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
return
false
;
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
auto
input_arg_t
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
if
(
group
!=
1
)
...
...
@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if
(
ins
->
get_shape
().
lens
().
size
()
!=
4
)
return
false
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
float_type
and
input_arg_t
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
int8_type
)
return
true
;
if
(
mode
==
mlir_mode
::
int8
)
...
...
@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
...
...
@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax"
,
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
...
...
@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
// supported.
if
(
is_float
and
name
==
"convert"
)
{
if
(
result_type
==
shape
::
fp8e4m3fnuz_type
)
{
return
false
;
}
// else
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
...
...
@@ -404,12 +415,13 @@ struct find_mlir_standalone_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
gemm_based_op
=
r
.
result
;
//
// enable only for fp32/fp16/i8 types
// enable only for fp32/fp16/i8/fp8 types
if
(
std
::
any_of
(
gemm_based_op
->
inputs
().
begin
(),
gemm_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
,
shape
::
type_t
::
fp8e4m3fnuz_type
},
i
->
get_shape
().
type
());
}))
return
;
static
size_t
counter
=
0
;
...
...
@@ -531,7 +543,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match
::
find_matches
(
mpm
,
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
int8
)},
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
fast
)},
find_mlir_standalone_dot_op
{
get_mode
(
"dot"
,
mlir_mode
::
none
)});
#else
(
void
)
mpm
;
...
...
src/targets/gpu/include/migraphx/gpu/device_name.hpp
View file @
4cc5393d
...
...
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT
int
get_device_id
();
MIGRAPHX_GPU_EXPORT
bool
gfx_has_fp8_intrinsics
();
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/gpu/mlir.cpp
View file @
4cc5393d
...
...
@@ -300,6 +300,8 @@ struct mlir_program
result
=
mlirF32TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
half_type
)
result
=
mlirF16TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
fp8e4m3fnuz_type
)
result
=
mlirFloat8E4M3FNUZTypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
double_type
)
result
=
mlirF64TypeGet
(
ctx
.
get
());
else
if
(
as
.
is_integral
())
...
...
src/targets/gpu/rocblas.cpp
View file @
4cc5393d
...
...
@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return
false
;
#else
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
return
gfx_has_fp8_intrinsics
();
#endif
}
...
...
src/targets/gpu/target.cpp
View file @
4cc5393d
...
...
@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
// whiltelist supported Ops for the FP8
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops
.
insert
(
"pooling"
);
if
(
not
gpu
::
gfx_has_fp8_intrinsics
())
{
unsupported_fp8_ops
.
insert
(
"convolution"
);
unsupported_fp8_ops
.
insert
(
"quant_convolution"
);
}
// add all device kernels
unsupported_fp8_ops
.
insert
(
"logsoftmax"
);
unsupported_fp8_ops
.
insert
(
"nonzero"
);
...
...
test/simplify_qdq_test.cpp
View file @
4cc5393d
...
...
@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_add_multiple_dq_use
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
32
,
1
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
32
,
32
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
d1_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d1
);
auto
d1_tmb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
d1_t
);
auto
d1_tmbc
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
d1_tmb
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
dot_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1_tmbc
,
d2
);
auto
q3
=
add_quantize_op
(
m1
,
"quantizelinear"
,
dot_1
,
scale
,
zero
);
auto
d3
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q3
,
scale
,
zero
);
auto
dot_2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d3
,
d1
);
auto
add
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
dot_2
,
d1
});
m1
.
add_return
({
add
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q1_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q1
);
auto
q1_tmb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
q1_t
);
auto
q1_tmbc
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
q1_tmb
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1_tmbc
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_1
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_1
,
out_scale
);
auto
d3_q
=
add_quantize_op
(
m2
,
"quantizelinear"
,
d3
,
scale
,
zero
);
auto
dot_2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
d3_q
,
q1
);
auto
out_scale_2
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_2
->
get_shape
().
lens
());
auto
d4
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_2
,
out_scale_2
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d4
,
t1
);
m2
.
add_return
({
add
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv
)
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
...
...
@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto
mod1
=
create_module
();
auto
mod2
=
create_module
();
run_pass
(
mod2
);
auto
match_qdq
=
migraphx
::
match
::
name
(
"dequantizelinear"
)(
...
...
test/verify/main.cpp
View file @
4cc5393d
...
...
@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
});
rv
.
disable_test_for
(
"gpu"
,
{
"test_conv_bn_add"
});
rv
.
run
(
argc
,
argv
);
}
test/verify/quant_conv.cpp
View file @
4cc5393d
...
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv
:
verify_program
<
quant_conv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv
:
verify_program
<
quant_conv
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
pa
,
pc
);
return
p
;
}
};
template
struct
quant_conv
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment