Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8d32c6b8
Commit
8d32c6b8
authored
Oct 17, 2023
by
Paul
Browse files
Merge branch 'develop' into blas_tuning
parents
23cb7917
f25606f9
Changes
386
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
667 additions
and
89 deletions
+667
-89
src/include/migraphx/source_location.hpp
src/include/migraphx/source_location.hpp
+1
-0
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+5
-1
src/include/migraphx/tmp_dir.hpp
src/include/migraphx/tmp_dir.hpp
+1
-0
src/include/migraphx/type_name.hpp
src/include/migraphx/type_name.hpp
+1
-1
src/include/migraphx/verify.hpp
src/include/migraphx/verify.hpp
+96
-8
src/include/migraphx/verify_args.hpp
src/include/migraphx/verify_args.hpp
+9
-5
src/load_save.cpp
src/load_save.cpp
+21
-0
src/memory_coloring.cpp
src/memory_coloring.cpp
+5
-3
src/msgpack.cpp
src/msgpack.cpp
+56
-8
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+42
-14
src/onnx/broadcast_qdq.cpp
src/onnx/broadcast_qdq.cpp
+76
-0
src/onnx/include/migraphx/onnx/broadcast_qdq.hpp
src/onnx/include/migraphx/onnx/broadcast_qdq.hpp
+56
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+1
-1
src/onnx/padding.cpp
src/onnx/padding.cpp
+1
-1
src/onnx/parse_castlike.cpp
src/onnx/parse_castlike.cpp
+53
-0
src/onnx/parse_constant.cpp
src/onnx/parse_constant.cpp
+26
-3
src/onnx/parse_constant_of_shape.cpp
src/onnx/parse_constant_of_shape.cpp
+28
-19
src/onnx/parse_depthtospace.cpp
src/onnx/parse_depthtospace.cpp
+1
-2
src/onnx/parse_pooling.cpp
src/onnx/parse_pooling.cpp
+34
-23
src/onnx/parse_qlinearadd.cpp
src/onnx/parse_qlinearadd.cpp
+154
-0
No files found.
src/include/migraphx/source_location.hpp
View file @
8d32c6b8
...
...
@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#include <cstdint>
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
...
...
src/include/migraphx/stringutils.hpp
View file @
8d32c6b8
...
...
@@ -86,7 +86,7 @@ inline std::string join_strings(Strings strings, const std::string& delim)
inline
std
::
vector
<
std
::
string
>
split_string
(
const
std
::
string
&
s
,
char
delim
)
{
std
::
vector
<
std
::
string
>
elems
;
std
::
stringstream
ss
(
s
+
' '
);
std
::
stringstream
ss
(
s
+
delim
);
std
::
string
item
;
while
(
std
::
getline
(
ss
,
item
,
delim
))
{
...
...
@@ -149,6 +149,10 @@ interpolate_string(const std::string& input, F f, std::string start = "${", std:
result
.
append
(
it
,
next_start
);
if
(
next_start
==
input
.
end
())
break
;
if
(
next_end
==
input
.
end
())
{
throw
std
::
runtime_error
(
"Unbalanced brackets"
);
}
auto
r
=
f
(
next_start
+
start
.
size
(),
next_end
);
result
.
append
(
r
.
begin
(),
r
.
end
());
it
=
next_end
+
end
.
size
();
...
...
src/include/migraphx/tmp_dir.hpp
View file @
8d32c6b8
...
...
@@ -34,6 +34,7 @@ struct MIGRAPHX_EXPORT tmp_dir
{
fs
::
path
path
;
tmp_dir
(
const
std
::
string
&
prefix
=
""
);
tmp_dir
(
tmp_dir
&&
)
=
default
;
void
execute
(
const
std
::
string
&
exe
,
const
std
::
string
&
args
)
const
;
...
...
src/include/migraphx/type_name.hpp
View file @
8d32c6b8
...
...
@@ -34,7 +34,7 @@ template <class PrivateMigraphTypeNameProbe>
std
::
string
compute_type_name
()
{
std
::
string
name
;
#ifdef
_MSC_VER
#if
def
ined(
_MSC_VER
) && !defined(__clang__)
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
name
.
substr
(
7
);
#else
...
...
src/include/migraphx/verify.hpp
View file @
8d32c6b8
...
...
@@ -29,10 +29,13 @@
#include <functional>
#include <iostream>
#include <numeric>
#include <assert.h>
#include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_VERIFY_ENABLE_ALLCLOSE
)
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
verify
{
...
...
@@ -87,8 +90,7 @@ struct not_finite_fn
template
<
class
T
>
bool
operator
()(
T
x
)
const
{
using
std
::
isfinite
;
return
not
isfinite
(
x
);
return
not
std
::
isfinite
(
static_cast
<
double
>
(
x
));
}
};
static
constexpr
not_finite_fn
not_finite
{};
...
...
@@ -98,8 +100,7 @@ struct compare_mag_fn
template
<
class
T
,
class
U
>
bool
operator
()(
T
x
,
U
y
)
const
{
using
std
::
fabs
;
return
fabs
(
x
)
<
fabs
(
y
);
return
std
::
fabs
(
x
)
<
std
::
fabs
(
y
);
}
};
static
constexpr
compare_mag_fn
compare_mag
{};
...
...
@@ -187,16 +188,103 @@ double rms_range(const R1& r1, const R2& r2)
return
std
::
numeric_limits
<
range_value
<
R1
>>::
max
();
}
template
<
class
R
>
double
get_rms_tol
(
const
R
&
,
std
::
size_t
tolerance
=
80
)
{
double
threshold
=
std
::
numeric_limits
<
range_value
<
R
>>::
epsilon
()
*
tolerance
;
return
threshold
;
}
/*
C++ doesn't support named arguments, this is just wrapper that helps distinguish between actual
results v/s expected results arguments.
*/
template
<
class
T
>
struct
expected
{
expected
()
=
default
;
explicit
expected
(
const
T
&
input
)
:
x
(
&
input
)
{}
const
T
&
data
()
const
{
assert
(
x
!=
nullptr
);
return
*
x
;
}
private:
const
T
*
x
=
nullptr
;
};
// deduction guide for templated expected class
template
<
class
T
>
expected
(
const
T
&
)
->
expected
<
T
>
;
struct
tolerance
{
double
rms_tol
=
0.001
;
double
atol
=
0.001
;
double
rtol
=
0.001
;
};
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template
<
class
R1
,
class
R2
>
bool
allclose
(
const
R1
&
r1
,
const
R2
&
r2
,
tolerance
tols
)
{
std
::
size_t
n
=
range_distance
(
r1
);
if
(
n
==
range_distance
(
r2
))
{
auto
idx
=
mismatch_idx
(
r1
,
r2
,
[
&
](
auto
x
,
auto
y
)
{
return
abs_diff
(
double
(
x
),
double
(
y
))
<
tols
.
atol
+
tols
.
rtol
*
std
::
abs
(
double
(
y
));
});
return
idx
>=
range_distance
(
r1
);
}
return
false
;
}
template
<
class
R1
,
class
R2
>
bool
verify_range
(
const
R1
&
r1
,
const
R2
&
r2
,
double
tolerance
=
80
,
double
*
out_error
=
nullptr
)
bool
verify_rms_range
(
const
R1
&
r1
,
const
R2
&
r2
,
std
::
size_t
tolerance
=
80
,
double
*
out_rms_error
=
nullptr
)
{
double
threshold
=
std
::
numeric_limits
<
range_value
<
R1
>>::
epsilon
()
*
tolerance
;
double
threshold
=
get_rms_tol
(
r1
,
tolerance
)
;
auto
error
=
rms_range
(
r1
,
r2
);
if
(
out_error
!=
nullptr
)
*
out_error
=
error
;
if
(
out_
rms_
error
!=
nullptr
)
*
out_
rms_
error
=
error
;
return
error
<=
threshold
;
}
template
<
class
R1
,
class
R2
>
bool
verify_range_with_tolerance
(
const
R1
&
r1
,
const
expected
<
R2
>&
r2
,
tolerance
tols
=
tolerance
{},
double
*
out_rms_error
=
nullptr
)
{
auto
rms_error
=
rms_range
(
r1
,
r2
.
data
());
// disable ewise_verify by default for now, it requires lot of tests to be fixed
bool
ewise_verify
=
true
;
if
(
enabled
(
MIGRAPHX_VERIFY_ENABLE_ALLCLOSE
{}))
{
ewise_verify
=
allclose
(
r1
,
r2
.
data
(),
tols
);
}
if
(
out_rms_error
!=
nullptr
)
*
out_rms_error
=
rms_error
;
return
rms_error
<=
tols
.
rms_tol
and
ewise_verify
;
}
// expected argument should be passed as second, but if it is passed as the first by mistake then
// flip the order
template
<
class
R1
,
class
R2
>
bool
verify_range_with_tolerance
(
const
expected
<
R1
>&
r1
,
const
R2
&
r2
,
tolerance
tols
=
tolerance
{},
double
*
out_rms_error
=
nullptr
)
{
return
verify_rms_range
(
r2
,
r1
,
tols
,
out_rms_error
);
}
}
// namespace verify
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/verify_args.hpp
View file @
8d32c6b8
...
...
@@ -31,11 +31,15 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_EXPORT
bool
verify_args
(
const
std
::
string
&
name
,
const
argument
&
ref_arg
,
const
argument
&
target_arg
,
double
tolerance
=
80
);
MIGRAPHX_EXPORT
bool
verify_args
(
const
std
::
string
&
name
,
const
argument
&
target_arg
,
const
verify
::
expected
<
argument
>&
ref_arg
,
verify
::
tolerance
);
MIGRAPHX_EXPORT
bool
verify_args_with_tolerance
(
const
std
::
string
&
name
,
const
argument
&
target_arg
,
const
verify
::
expected
<
argument
>&
ref_arg
,
std
::
size_t
tolerance
=
80
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/load_save.cpp
View file @
8d32c6b8
...
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp>
...
...
@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt
{
write_buffer
(
filename
,
save_buffer
(
p
,
options
));
}
// MIOpen doesn't support serializing fusion plans with Find-2.0 APIs
void
print_miopen_warning
(
const
program
&
p
)
{
auto
mods
=
p
.
get_modules
();
if
(
std
::
any_of
(
mods
.
begin
(),
mods
.
end
(),
[](
const
auto
*
m
)
{
return
std
::
any_of
(
m
->
begin
(),
m
->
end
(),
[](
const
instruction
&
i
)
{
return
i
.
name
()
==
"gpu::miopen_fusion"
;
});
}))
{
std
::
cout
<<
"[WARNING]: Program has miopen_fusion instructions for which tuned solutions "
"are not stored inside serialized MIGraphX program. Consider serializing with "
"MIGRAPHX_DISABLE_MIOPEN_FUSION=1 flag set."
<<
std
::
endl
;
;
}
}
std
::
vector
<
char
>
save_buffer
(
const
program
&
p
,
const
file_options
&
options
)
{
value
v
=
p
.
to_value
();
print_miopen_warning
(
p
);
std
::
vector
<
char
>
buffer
;
if
(
options
.
format
==
"msgpack"
)
{
...
...
src/memory_coloring.cpp
View file @
8d32c6b8
...
...
@@ -23,9 +23,9 @@
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
...
...
@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
auto
s
=
ins
->
get_shape
();
std
::
size_t
offset
=
seg
.
first
*
alignment
;
assert
(
offset
<
n
);
m
.
replace_instruction
(
ins
,
op
::
load
{
s
,
offset
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
// Replace zero allocation
...
...
@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
m
.
replace_instruction
(
ins
,
op
::
load
{
ins
->
get_shape
(),
0
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
0
}}),
mem
);
}
// Remove scratch parameter if its not used
...
...
src/msgpack.cpp
View file @
8d32c6b8
...
...
@@ -25,6 +25,33 @@
#include <migraphx/serialize.hpp>
#include <msgpack.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Leave an extra byte for error checking
constexpr
std
::
size_t
msgpack_size_limit
=
std
::
numeric_limits
<
uint32_t
>::
max
()
-
1
;
template
<
class
Range
>
std
::
size_t
msgpack_chunk_size
(
const
Range
&
r
)
{
return
1
+
(
r
.
size
()
-
1
)
/
msgpack_size_limit
;
}
template
<
class
Iterator
,
class
F
>
void
msgpack_chunk_for_each
(
Iterator
start
,
Iterator
last
,
F
f
)
{
while
(
std
::
distance
(
start
,
last
)
>
msgpack_size_limit
)
{
auto
next
=
std
::
next
(
start
,
msgpack_size_limit
);
f
(
start
,
next
);
start
=
next
;
}
f
(
start
,
last
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
namespace
msgpack
{
MSGPACK_API_VERSION_NAMESPACE
(
MSGPACK_DEFAULT_API_NS
)
{
...
...
@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
break
;
}
case
msgpack
::
type
::
BIN
:
{
// For backwards compatibility
v
=
migraphx
::
value
::
binary
{
o
.
via
.
bin
.
ptr
,
o
.
via
.
bin
.
size
};
break
;
}
case
msgpack
::
type
::
ARRAY
:
{
migraphx
::
value
r
=
migraphx
::
value
::
array
{};
std
::
for_each
(
o
.
via
.
array
.
ptr
,
o
.
via
.
array
.
ptr
+
o
.
via
.
array
.
size
,
[
&
](
const
msgpack
::
object
&
so
)
{
r
.
push_back
(
so
.
as
<
migraphx
::
value
>
());
});
v
=
r
;
if
(
o
.
via
.
array
.
size
!=
0
and
o
.
via
.
array
.
ptr
->
type
==
msgpack
::
type
::
BIN
)
{
auto
bin
=
migraphx
::
value
::
binary
{};
std
::
for_each
(
o
.
via
.
array
.
ptr
,
o
.
via
.
array
.
ptr
+
o
.
via
.
array
.
size
,
[
&
](
const
msgpack
::
object
&
so
)
{
bin
.
insert
(
bin
.
end
(),
so
.
via
.
bin
.
ptr
,
so
.
via
.
bin
.
ptr
+
so
.
via
.
bin
.
size
);
});
v
=
bin
;
}
else
{
migraphx
::
value
r
=
migraphx
::
value
::
array
{};
std
::
for_each
(
o
.
via
.
array
.
ptr
,
o
.
via
.
array
.
ptr
+
o
.
via
.
array
.
size
,
[
&
](
const
msgpack
::
object
&
so
)
{
r
.
push_back
(
so
.
as
<
migraphx
::
value
>
());
});
v
=
r
;
}
break
;
}
case
msgpack
::
type
::
MAP
:
{
...
...
@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
const
auto
*
data
=
reinterpret_cast
<
const
char
*>
(
x
.
data
());
auto
size
=
x
.
size
();
o
.
pack_bin
(
size
);
o
.
pack_bin_body
(
data
,
size
);
o
.
pack_array
(
migraphx
::
msgpack_chunk_size
(
x
));
migraphx
::
msgpack_chunk_for_each
(
data
,
data
+
size
,
[
&
](
const
char
*
start
,
const
char
*
last
)
{
o
.
pack_bin
(
last
-
start
);
o
.
pack_bin_body
(
start
,
last
-
start
);
});
return
o
;
}
};
...
...
@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
o
.
pack_array
(
0
);
return
;
}
if
(
v
.
size
()
>
migraphx
::
msgpack_size_limit
)
MIGRAPHX_THROW
(
"Size is too large for msgpack"
);
if
(
not
v
.
front
().
get_key
().
empty
())
{
o
.
pack_map
(
v
.
size
());
...
...
src/normalize_attributes.cpp
View file @
8d32c6b8
...
...
@@ -26,7 +26,7 @@
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -49,6 +49,10 @@ auto tune_attribute(const std::vector<int64_t>& vec,
Message
m
)
{
std
::
vector
<
int64_t
>
result
(
vec
);
if
(
result
.
empty
())
{
return
result
;
};
int64_t
n_rank
=
input_shape
.
ndim
();
std
::
vector
<
op
::
normalize_attribute
>
vec_attrs
=
val
.
to_vector
<
op
::
normalize_attribute
>
();
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
use_output
))
...
...
@@ -188,20 +192,27 @@ bool normalize_attributes(operation& op, const shape& input_shape)
auto
val
=
op
.
to_value
();
if
(
attrs
.
contains
(
"normalize_padding"
))
{
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
auto
padding_start
=
2
;
if
(
padding_size
==
2
*
(
input_shape
.
ndim
()
-
padding_start
))
tuned
=
true
;
else
if
(
padding_size
!=
(
input_shape
.
ndim
()
-
padding_start
))
MIGRAPHX_THROW
(
"inconsistent padding size"
);
else
bool
use_auto_padding
=
(
val
.
contains
(
"padding_mode"
)
and
(
val
.
at
(
"padding_mode"
).
to
<
int
>
()
!=
migraphx
::
op
::
padding_mode_t
::
default_
));
if
(
not
use_auto_padding
)
{
auto
result
=
tune_pad_attribute
(
padding
);
val
[
"padding"
]
=
result
;
op
.
from_value
(
val
);
tuned
=
true
;
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
auto
padding_start
=
2
;
if
(
padding_size
==
2
*
(
input_shape
.
ndim
()
-
padding_start
))
tuned
=
true
;
else
if
(
padding_size
!=
(
input_shape
.
ndim
()
-
padding_start
))
{
MIGRAPHX_THROW
(
"normalize_attributes: inconsistent padding vector size "
);
}
else
{
auto
result
=
tune_pad_attribute
(
padding
);
val
[
"padding"
]
=
result
;
op
.
from_value
(
val
);
tuned
=
true
;
}
}
}
if
(
not
attrs
.
contains
(
"normalize_axes"
))
...
...
@@ -251,5 +262,22 @@ bool normalize_attributes(operation& op, const shape& input_shape)
return
tuned
;
}
std
::
vector
<
int64_t
>
normalize_axes
(
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
)
{
return
tune_attribute
(
axes
,
{},
attr_val
,
input_shape
,
[
&
]
{
return
prefix
;
});
}
std
::
vector
<
int64_t
>
normalize_indices
(
const
std
::
vector
<
int64_t
>&
indices
,
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
)
{
return
tune_attribute
(
indices
,
axes
,
attr_val
,
input_shape
,
[
&
]
{
return
prefix
;
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/broadcast_qdq.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/broadcast_qdq.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
// This method is to prep for quantizelinear or dequantizelinear operation for
// either the broadcasting of weight-scale or zero-points of qlinearadd operator
// outputs: operator op (inputs x, broadcasted: scale (float) & zero_pt (8-bit))
instruction_ref
bcast_qdq_instr
(
const
std
::
string
&
op_name
,
instruction_ref
x_in
,
instruction_ref
arg_fscale
,
instruction_ref
arg_z_pt
,
const
onnx_parser
::
node_info
&
info
)
{
auto
in_lens
=
x_in
->
get_shape
().
lens
();
// prep 1: broadcast scale. it can come as a scalar or a 1-D tensor.
instruction_ref
bcast_scale
;
if
(
arg_fscale
->
get_shape
().
elements
()
>
1
)
bcast_scale
=
info
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
},
{
"out_lens"
,
in_lens
}}),
arg_fscale
);
else
bcast_scale
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
in_lens
}}),
arg_fscale
);
// prep 2: broadcast zero point. it can come as a scalar or a 1-D tensor.
instruction_ref
bcast_zero_pt
;
if
(
arg_z_pt
->
get_shape
().
elements
()
>
1
)
bcast_zero_pt
=
info
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
},
{
"out_lens"
,
in_lens
}}),
arg_z_pt
);
else
bcast_zero_pt
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
in_lens
}}),
arg_z_pt
);
// op_name is either quantizelinear or dequantizelinear:
return
info
.
add_instruction
(
migraphx
::
make_op
(
op_name
),
x_in
,
bcast_scale
,
bcast_zero_pt
);
}
// Multibroadcast a scaler..
instruction_ref
bcast_scalar_instr
(
const
migraphx
::
shape
&
shape_out
,
instruction_ref
arg_in
,
const
onnx_parser
::
node_info
&
info
)
{
auto
bcast_instr_out
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape_out
.
lens
()}}),
arg_in
);
return
bcast_instr_out
;
}
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/include/migraphx/onnx/broadcast_qdq.hpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_BROADCAST_QDQ_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_BROADCAST_QDQ_HPP
#include <string>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
// This method is to prep for quantizelinear or dequantizelinear operation for
// either the broadcasting of weight-scale or zero-points of qlinearadd operator
// outputs: operator op (inputs x, broadcasted: scale (float) & zero_pt (8-bit))
instruction_ref
bcast_qdq_instr
(
const
std
::
string
&
op_name
,
instruction_ref
x_in
,
instruction_ref
arg_fscale
,
instruction_ref
arg_z_pt
,
const
onnx_parser
::
node_info
&
info
);
// Multibroadcast a scaler..
instruction_ref
bcast_scalar_instr
(
const
migraphx
::
shape
&
shape_out
,
instruction_ref
arg_in
,
const
onnx_parser
::
node_info
&
info
);
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/onnx/onnx_parser.cpp
View file @
8d32c6b8
...
...
@@ -244,7 +244,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
this
->
filename
=
std
::
move
(
name
);
auto
parent_path
=
fs
::
path
(
this
->
filename
).
parent_path
();
if
(
not
parent_path
.
empty
())
this
->
path
=
parent_path
;
this
->
path
=
parent_path
.
string
()
;
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
...
...
src/onnx/padding.cpp
View file @
8d32c6b8
...
...
@@ -47,7 +47,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
return
;
}
auto
auto_pad
=
info
.
attributes
[
"auto_pad"
].
s
();
auto
auto_pad
=
to_upper
(
info
.
attributes
[
"auto_pad"
].
s
()
)
;
if
(
auto_pad
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
bool
is_same_upper
=
(
auto_pad
.
find
(
"SAME_UPPER"
)
!=
std
::
string
::
npos
);
...
...
src/onnx/parse_castlike.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_castlike
:
op_parser
<
parse_castlike
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"CastLike"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
if
(
not
(
args
.
size
()
==
2
))
{
MIGRAPHX_THROW
(
"PARSE_CASTLIKE: CastLike must have exactly 2 inputs!"
);
}
shape
::
type_t
target_type
=
args
[
1
]
->
get_shape
().
type
();
return
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
target_type
}}),
args
[
0
]);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_constant.cpp
View file @
8d32c6b8
...
...
@@ -25,6 +25,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
/*args*/
)
const
{
literal
v
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"value"
));
static
const
std
::
vector
<
std
::
string
>
attributes
=
{
"value"
,
"value_float"
,
"value_floats"
,
"value_int"
,
"value_ints"
};
std
::
vector
<
std
::
string
>
present_attributes
;
std
::
copy_if
(
attributes
.
begin
(),
attributes
.
end
(),
std
::
back_inserter
(
present_attributes
),
[
&
](
const
std
::
string
&
a
)
{
return
contains
(
info
.
attributes
,
a
);
});
if
(
present_attributes
.
empty
())
{
MIGRAPHX_THROW
(
"Constant node does not contain any supported attribute"
);
}
if
(
present_attributes
.
size
()
>
1
)
{
MIGRAPHX_THROW
(
"Constant contains multiple attributes: "
+
join_strings
(
std
::
move
(
present_attributes
),
", "
));
}
// cppcheck-suppress accessMoved
auto
&&
attr
=
info
.
attributes
[
present_attributes
[
0
]];
literal
v
=
parser
.
parse_value
(
attr
);
// return empty literal
if
(
v
.
get_shape
().
elements
()
==
0
)
{
return
info
.
add_literal
(
literal
{
v
.
get_shape
().
type
()});
}
auto
dim_size
=
info
.
attributes
.
at
(
"value"
).
t
().
dims_size
();
// if dim_size is 0, it is a scalar
if
(
dim_size
==
0
)
if
(
attr
.
has_t
()
and
attr
.
t
().
dim
s
_size
()
==
0
)
{
migraphx
::
shape
scalar_shape
{
v
.
get_shape
().
type
()};
return
info
.
add_literal
(
migraphx
::
literal
{
scalar_shape
,
v
.
data
()});
...
...
src/onnx/parse_constant_of_shape.cpp
View file @
8d32c6b8
...
...
@@ -49,6 +49,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
{
MIGRAPHX_THROW
(
"ConstantOfShape: attribute value can contain only 1 elements!"
);
}
// convert to a scalar literal
l_val
=
literal
(
shape
{
l_val
.
get_shape
().
type
(),
{
1
},
{
0
}},
l_val
.
data
());
}
else
{
...
...
@@ -64,30 +66,37 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
migraphx
::
shape
s
;
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
// empty input tensor, output is a scalar
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
migraphx
::
argument
input
=
args
[
0
]
->
eval
();
if
(
not
input
.
empty
()
)
{
s
=
migraphx
::
shape
{
type
,
{
1
},
{
0
}};
// empty input tensor, output is a scalar
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
{
s
=
migraphx
::
shape
{
type
,
{
1
},
{
0
}};
}
else
{
std
::
vector
<
std
::
size_t
>
dims
;
input
.
visit
([
&
](
auto
ia
)
{
dims
.
assign
(
ia
.
begin
(),
ia
.
end
());
});
s
=
migraphx
::
shape
{
type
,
dims
};
}
literal
l_out
{};
l_val
.
visit
([
&
](
auto
val
)
{
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
// l_val contains only one element
std
::
vector
<
val_type
>
out_vec
(
s
.
elements
(),
val
.
front
());
l_out
=
literal
(
s
,
out_vec
);
});
return
info
.
add_literal
(
l_out
);
}
// has variable input (dynamic shape buffer)
else
{
migraphx
::
argument
in
=
args
[
0
]
->
eval
();
check_arg_empty
(
in
,
"ConstantOfShape: dynamic shape is not supported"
);
std
::
vector
<
std
::
size_t
>
dims
;
in
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
s
=
migraphx
::
shape
{
type
,
dims
};
auto
dv_lit
=
info
.
add_literal
(
l_val
);
auto
alloc_ins
=
info
.
add_instruction
(
make_op
(
"allocate"
,
{{
"buf_type"
,
type
}}),
args
[
0
]);
return
info
.
add_instruction
(
make_op
(
"fill"
),
dv_lit
,
alloc_ins
);
}
literal
l_out
{};
l_val
.
visit
([
&
](
auto
val
)
{
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
// l_val contains only one element
std
::
vector
<
val_type
>
out_vec
(
s
.
elements
(),
val
.
front
());
l_out
=
literal
(
s
,
out_vec
);
});
return
info
.
add_literal
(
l_out
);
}
}
};
...
...
src/onnx/parse_depthtospace.cpp
View file @
8d32c6b8
...
...
@@ -87,8 +87,7 @@ struct parse_depthtospace : op_parser<parse_depthtospace>
auto
temp1
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens1
}}),
args
[
0
]);
auto
temp2
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
temp1
);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens2
}}),
info
.
make_contiguous
(
temp2
));
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens2
}}),
temp2
);
}
};
...
...
src/onnx/parse_pooling.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -97,7 +97,7 @@ struct parse_pooling : op_parser<parse_pooling>
values
[
"lp_order"
]
=
info
.
attributes
.
at
(
"p"
).
i
();
}
// ensure pads availabe only when auto_pad is "NOT_SET"
// ensure pads availab
l
e only when auto_pad is "NOT_SET"
check_padding_mode
(
info
,
"POOLING"
);
return
values
;
...
...
@@ -151,26 +151,6 @@ struct parse_pooling : op_parser<parse_pooling>
kdims
,
paddings
.
size
()
/
2
,
"PARSE_POOLING: inconsistent explicit paddings"
);
}
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
{
if
(
in_shape
.
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_POOLING: Auto padding pooling with dynamic input shape not supported"
);
}
else
{
values
[
"padding"
].
clear
();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size
(
info
,
values
,
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
{
1
,
1
},
in_shape
.
lens
(),
paddings
);
}
}
if
(
paddings
.
size
()
!=
2
*
kdims
)
{
paddings
.
resize
(
kdims
*
2
);
...
...
@@ -192,6 +172,36 @@ struct parse_pooling : op_parser<parse_pooling>
// used to calculate the supposed output shape
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
// TODO: add parsing for dilations
if
(
contains
(
info
.
attributes
,
"auto_pad"
)
and
to_upper
(
info
.
attributes
[
"auto_pad"
].
s
())
!=
"NOTSET"
)
{
auto
auto_pad
=
to_upper
(
info
.
attributes
[
"auto_pad"
].
s
());
// don't use the given padding sizes, if any
// values["padding"].clear();
if
(
in_shape
.
dynamic
())
{
// set padding_mode to trigger auto padding at runtime
bool
is_same_upper
=
(
auto_pad
.
find
(
"SAME_UPPER"
)
!=
std
::
string
::
npos
);
values
[
"padding_mode"
]
=
is_same_upper
?
to_value
(
op
::
padding_mode_t
::
same_upper
)
:
to_value
(
op
::
padding_mode_t
::
same_lower
);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size
(
info
,
values
,
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
std
::
vector
<
size_t
>
(
in_shape
.
ndim
()
-
2
,
1
),
in_shape
.
lens
(),
paddings
);
values
[
"padding"
]
=
paddings
;
// default padding_mode indicates that padding sizes are not calculated dynamically
values
[
"padding_mode"
]
=
migraphx
::
op
::
padding_mode_t
::
default_
;
}
}
std
::
vector
<
int64_t
>
slice_start
;
std
::
vector
<
int64_t
>
slice_end
;
tune_padding_size
(
values
,
paddings
,
count_include_pad
,
slice_start
);
...
...
@@ -208,8 +218,9 @@ struct parse_pooling : op_parser<parse_pooling>
orig_padding
.
insert
(
orig_padding
.
begin
(),
2
,
0
);
op
::
pad
pad
{
orig_padding
,
0.0
f
};
shape
padded_shape
=
pad
.
compute_shape
({
l0
->
get_shape
()});
auto
out_lens
=
make_op
(
"pooling"
,
values
).
compute_shape
({
padded_shape
}).
lens
();
// make an op just to get its output shape
auto
out_lens
=
make_op
(
"pooling"
,
values
).
compute_shape
({
padded_shape
}).
lens
();
// compute slice_end information
slice_end
.
resize
(
slice_start
.
size
());
std
::
transform
(
out_lens
.
begin
()
+
2
,
...
...
src/onnx/parse_qlinearadd.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
/*
*********************************************************************************
* Reference: see QLinearAdd in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
com.microsoft.QLinearAdd
Performs element-wise binary addition on 8 bit data types (with Numpy-style broadcasting support).
C = (A_scale * (A - A_zero_point) + B_scale * (B - B_zero_point))/C_scale + C_zero_point
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
Inputs (7 - 8)
A : T
First operand.
A_scale : tensor(float)
Input A's scale. It's a scalar, which means a per-tensor/layer quantization.
A_zero_point (optional) : T
Input A zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
B : T
Second operand.
B_scale : tensor(float)
Input B's scale. It's a scalar, which means a per-tensor/layer quantization.
B_zero_point (optional) : T
Input B zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
C_scale : tensor(float)
Output scale. It's a scalar, which means a per-tensor/layer quantization.
C_zero_point (optional) : T
Output zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
Outputs
C : T
Result, has same element type as two inputs
Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit signed and unsigned tensors.
*/
struct
parse_qlinearadd
:
op_parser
<
parse_qlinearadd
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"QLinearAdd"
}};
}
// basic type checking for QLinearAdd Operator
void
check_inputs
(
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
if
(
args
.
size
()
<
7
)
MIGRAPHX_THROW
(
"QLINEARADD: missing inputs"
);
const
auto
&
in_a
=
args
[
0
];
const
auto
&
in_b
=
args
[
3
];
auto
sh_a
=
in_a
->
get_shape
();
auto
sh_b
=
in_b
->
get_shape
();
auto
type_a
=
sh_a
.
type
();
auto
type_b
=
sh_b
.
type
();
if
(
type_a
!=
migraphx
::
shape
::
int8_type
and
type_a
!=
migraphx
::
shape
::
uint8_type
)
MIGRAPHX_THROW
(
"QLINEARADD: unsupported input type"
);
if
(
type_b
!=
migraphx
::
shape
::
int8_type
and
type_b
!=
migraphx
::
shape
::
uint8_type
)
MIGRAPHX_THROW
(
"QLINEARADD: unsupported input type"
);
if
(
type_a
!=
type_b
)
MIGRAPHX_THROW
(
"QLINEARADD: mismatched input types"
);
}
instruction_ref
parse
(
const
op_desc
&
/* opd */
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
check_inputs
(
args
);
// A
const
auto
&
in_a
=
args
[
0
];
const
auto
&
in_scale_a
=
args
[
1
];
const
auto
&
in_zero_pt_a
=
args
[
2
];
auto
dquant_a
=
bcast_qdq_instr
(
"dequantizelinear"
,
in_a
,
in_scale_a
,
in_zero_pt_a
,
info
);
// B
const
auto
&
in_b
=
args
[
3
];
const
auto
&
in_scale_b
=
args
[
4
];
const
auto
&
in_zero_pt_b
=
args
[
5
];
auto
dquant_b
=
bcast_qdq_instr
(
"dequantizelinear"
,
in_b
,
in_scale_b
,
in_zero_pt_b
,
info
);
// C = A + B
auto
out_c
=
info
.
add_common_op
(
"add"
,
dquant_a
,
dquant_b
);
const
auto
&
in_scale_c
=
args
[
6
];
// zero_pt for C is supplied as the last optional argument..
if
(
args
.
size
()
==
8
)
return
(
bcast_qdq_instr
(
"quantizelinear"
,
out_c
,
in_scale_c
,
args
[
7
],
info
));
// if no zero_pt: just broadcast the scale..
auto
bcast_scale_c
=
bcast_scalar_instr
(
out_c
->
get_shape
(),
in_scale_c
,
info
);
return
(
info
.
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
out_c
,
bcast_scale_c
));
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
Prev
1
2
3
4
5
6
7
8
9
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment