Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f12064ee
Commit
f12064ee
authored
Aug 25, 2023
by
umangyadav
Browse files
Merge branch 'develop' into resnet50_partition
parents
2c4f70be
6f1c947f
Changes
126
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
373 additions
and
102 deletions
+373
-102
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+4
-0
src/instruction.cpp
src/instruction.cpp
+3
-9
src/memory_coloring.cpp
src/memory_coloring.cpp
+5
-3
src/module.cpp
src/module.cpp
+6
-7
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+21
-0
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+1
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+12
-9
src/onnx/parse_constant_of_shape.cpp
src/onnx/parse_constant_of_shape.cpp
+2
-3
src/onnx/parse_randomuniform_ops.cpp
src/onnx/parse_randomuniform_ops.cpp
+1
-1
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+82
-49
src/permutation.cpp
src/permutation.cpp
+10
-0
src/program.cpp
src/program.cpp
+7
-5
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+14
-3
src/py/include/migraphx/py.hpp
src/py/include/migraphx/py.hpp
+38
-0
src/py/py.cpp
src/py/py.cpp
+76
-0
src/py/py_loader.cpp
src/py/py_loader.cpp
+74
-0
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+5
-7
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+6
-5
src/sqlite.cpp
src/sqlite.cpp
+1
-0
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+5
-1
No files found.
src/include/migraphx/permutation.hpp
View file @
f12064ee
...
@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
...
@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
/// Normalize the shapes so the order of dimensions will be in the order it is
/// in memory as much as possible.
MIGRAPHX_EXPORT
std
::
vector
<
shape
>
normalize_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/instruction.cpp
View file @
f12064ee
...
@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
...
@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result
=
r
;
result
=
r
;
for
(
auto
&&
ins
:
output
)
for
(
auto
&&
ins
:
output
)
{
{
if
(
ins
->
name
()
==
"@return"
)
assert
(
ins
->
name
()
==
"@return"
or
ins
->
name
().
front
()
!=
'@'
);
continue
;
assert
(
ins
->
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
ins
->
recompute_shape
();
}
}
}
}
...
@@ -122,10 +119,6 @@ bool instruction::valid() const
...
@@ -122,10 +119,6 @@ bool instruction::valid() const
{
{
computed
=
result
;
computed
=
result
;
}
}
else
if
(
op
.
name
()
==
"@return"
)
{
computed
=
{};
}
else
else
{
{
try
try
...
@@ -145,6 +138,7 @@ bool instruction::valid() const
...
@@ -145,6 +138,7 @@ bool instruction::valid() const
}
}
shape
instruction
::
get_shape
()
const
{
return
result
;
}
shape
instruction
::
get_shape
()
const
{
return
result
;
}
const
literal
&
instruction
::
get_literal
()
const
const
literal
&
instruction
::
get_literal
()
const
{
{
assert
(
op
.
name
()
==
"@literal"
);
assert
(
op
.
name
()
==
"@literal"
);
...
@@ -395,7 +389,7 @@ void instruction::print(std::ostream& os,
...
@@ -395,7 +389,7 @@ void instruction::print(std::ostream& os,
if
(
not
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
{
{
std
::
string
delim
=
", ["
;
std
::
string
delim
=
", ["
;
for
(
auto
&
&
mod_arg
:
ins
->
module_inputs
())
for
(
const
const_module_ref
&
mod_arg
:
ins
->
module_inputs
())
{
{
os
<<
delim
<<
mod_arg
->
name
();
os
<<
delim
<<
mod_arg
->
name
();
delim
=
", "
;
delim
=
", "
;
...
...
src/memory_coloring.cpp
View file @
f12064ee
...
@@ -23,9 +23,9 @@
...
@@ -23,9 +23,9 @@
*/
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
...
@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
std
::
size_t
offset
=
seg
.
first
*
alignment
;
std
::
size_t
offset
=
seg
.
first
*
alignment
;
assert
(
offset
<
n
);
assert
(
offset
<
n
);
m
.
replace_instruction
(
ins
,
op
::
load
{
s
,
offset
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
}
// Replace zero allocation
// Replace zero allocation
...
@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
...
@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
continue
;
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
m
.
replace_instruction
(
ins
,
op
::
load
{
ins
->
get_shape
(),
0
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
0
}}),
mem
);
}
}
// Remove scratch parameter if its not used
// Remove scratch parameter if its not used
...
...
src/module.cpp
View file @
f12064ee
...
@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
...
@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
instruction_ref
module
::
add_return
(
std
::
vector
<
instruction_ref
>
args
)
instruction_ref
module
::
add_return
(
std
::
vector
<
instruction_ref
>
args
)
{
{
impl
->
push_back
({
builtin
::
returns
{},
{},
std
::
move
(
args
)});
shape
instr_shape
=
compute_shape
(
builtin
::
returns
{},
args
);
impl
->
push_back
({
builtin
::
returns
{},
instr_shape
,
std
::
move
(
args
)});
auto
result
=
std
::
prev
(
impl
->
instructions
.
end
());
auto
result
=
std
::
prev
(
impl
->
instructions
.
end
());
instruction
::
backreference
(
result
);
instruction
::
backreference
(
result
);
assert
(
result
->
valid
(
begin
()));
assert
(
result
->
valid
(
begin
()));
return
result
;
return
result
;
}
}
...
@@ -873,12 +873,11 @@ module::print_py(std::ostream& os,
...
@@ -873,12 +873,11 @@ module::print_py(std::ostream& os,
if
(
ins
->
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
)
{
{
os
<<
mname
<<
".add_literal("
;
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
const
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
// Disable abs for now
// Disable abs for now
use_abs
=
false
;
// ins->get_literal().visit([&](auto v) {
// use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
// });
if
(
use_abs
)
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_argument("
;
os
<<
"migraphx.generate_argument("
;
...
...
src/normalize_attributes.cpp
View file @
f12064ee
...
@@ -49,6 +49,10 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -49,6 +49,10 @@ auto tune_attribute(const std::vector<int64_t>& vec,
Message
m
)
Message
m
)
{
{
std
::
vector
<
int64_t
>
result
(
vec
);
std
::
vector
<
int64_t
>
result
(
vec
);
if
(
result
.
empty
())
{
return
result
;
};
int64_t
n_rank
=
input_shape
.
ndim
();
int64_t
n_rank
=
input_shape
.
ndim
();
std
::
vector
<
op
::
normalize_attribute
>
vec_attrs
=
val
.
to_vector
<
op
::
normalize_attribute
>
();
std
::
vector
<
op
::
normalize_attribute
>
vec_attrs
=
val
.
to_vector
<
op
::
normalize_attribute
>
();
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
use_output
))
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
use_output
))
...
@@ -251,5 +255,22 @@ bool normalize_attributes(operation& op, const shape& input_shape)
...
@@ -251,5 +255,22 @@ bool normalize_attributes(operation& op, const shape& input_shape)
return
tuned
;
return
tuned
;
}
}
std
::
vector
<
int64_t
>
normalize_axes
(
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
)
{
return
tune_attribute
(
axes
,
{},
attr_val
,
input_shape
,
[
&
]
{
return
prefix
;
});
}
std
::
vector
<
int64_t
>
normalize_indices
(
const
std
::
vector
<
int64_t
>&
indices
,
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
)
{
return
tune_attribute
(
indices
,
axes
,
attr_val
,
input_shape
,
[
&
]
{
return
prefix
;
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
f12064ee
...
@@ -117,6 +117,7 @@ struct onnx_parser
...
@@ -117,6 +117,7 @@ struct onnx_parser
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
=
false
);
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
=
false
);
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
};
};
...
...
src/onnx/onnx_parser.cpp
View file @
f12064ee
...
@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser,
...
@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser,
}
}
shape
s
;
shape
s
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
parser
.
map_input_dims
.
count
(
name
)
>
0
)
if
(
parser
.
map_input_dims
.
count
(
name
)
>
0
)
{
{
dims
=
parser
.
map_input_dims
.
at
(
name
);
std
::
vector
<
std
::
size_t
>
dims
=
parser
.
map_input_dims
.
at
(
name
);
s
=
parser
.
parse_type
(
input
.
type
(),
dims
);
s
=
parser
.
parse_type
(
input
.
type
(),
dims
);
}
}
else
if
(
parser
.
map_dyn_input_dims
.
count
(
name
)
>
0
)
else
if
(
parser
.
map_dyn_input_dims
.
count
(
name
)
>
0
)
...
@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser,
...
@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser,
}
}
else
else
{
{
s
=
parser
.
parse_type
(
input
.
type
()
,
dims
);
s
=
parser
.
parse_type
(
input
.
type
());
}
}
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
}
}
...
@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
...
@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
}
}
MIGRAPHX_THROW
(
"PARSE_TENSOR: Invalid tensor type"
);
MIGRAPHX_THROW
(
"PARSE_TENSOR: Invalid tensor type"
);
}
}
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
,
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
)
const
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
not
input_dims
.
empty
())
{
return
{
shape_type
,
input_dims
};
}
std
::
vector
<
shape
::
dynamic_dimension
>
dynamic_dims
;
std
::
vector
<
shape
::
dynamic_dimension
>
dynamic_dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
...
@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return
shape_from_dyn_dims
(
shape_type
,
dynamic_dims
);
return
shape_from_dyn_dims
(
shape_type
,
dynamic_dims
);
}
}
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
input_dims
.
empty
())
return
{
shape_type
};
return
{
shape_type
,
input_dims
};
}
shape
::
type_t
get_type
(
int
dtype
)
shape
::
type_t
get_type
(
int
dtype
)
{
{
switch
(
dtype
)
switch
(
dtype
)
...
...
src/onnx/parse_constant_of_shape.cpp
View file @
f12064ee
...
@@ -55,9 +55,6 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
...
@@ -55,9 +55,6 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val
=
literal
({
shape
::
float_type
,
{
1
},
{
0
}},
{
0.0
f
});
l_val
=
literal
({
shape
::
float_type
,
{
1
},
{
0
}},
{
0.0
f
});
}
}
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
if
(
args
.
empty
())
if
(
args
.
empty
())
{
{
MIGRAPHX_THROW
(
"ConstantOfShape : must have 1 input!"
);
MIGRAPHX_THROW
(
"ConstantOfShape : must have 1 input!"
);
...
@@ -65,6 +62,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
...
@@ -65,6 +62,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
else
else
{
{
migraphx
::
shape
s
;
migraphx
::
shape
s
;
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
// empty input tensor, output is a scalar
// empty input tensor, output is a scalar
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
{
{
...
...
src/onnx/parse_randomuniform_ops.cpp
View file @
f12064ee
...
@@ -96,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
...
@@ -96,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if
(
contains
(
info
.
attributes
,
"seed"
))
if
(
contains
(
info
.
attributes
,
"seed"
))
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
uniform_real_distribution
<>
d
(
high
,
low
);
std
::
uniform_real_distribution
<>
d
(
low
,
high
);
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
...
...
src/onnx/parse_slice.cpp
View file @
f12064ee
...
@@ -34,16 +34,65 @@ namespace onnx {
...
@@ -34,16 +34,65 @@ namespace onnx {
struct
parse_slice
:
op_parser
<
parse_slice
>
struct
parse_slice
:
op_parser
<
parse_slice
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Slice"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Slice"
}};
}
struct
slice_desc
{
op
::
slice
op
;
std
::
vector
<
instruction_ref
>
op_args
;
std
::
vector
<
int64_t
>
steps
;
std
::
vector
<
int64_t
>
raxes
;
void
always_insert
(
instruction_ref
arg
)
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
std
::
vector
<
int64_t
>
insert
(
instruction_ref
arg
)
{
std
::
vector
<
int64_t
>
result
;
migraphx
::
argument
arg_value
=
arg
->
eval
();
if
(
arg_value
.
empty
())
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
else
{
arg_value
.
visit
([
&
](
auto
s
)
{
result
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
return
result
;
}
};
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
parser
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
const
std
::
vector
<
instruction_ref
>
&
args
)
const
{
{
op
::
slice
op
;
auto
sd
=
construct_slice_desc
(
parser
,
info
,
args
);
auto
ins
=
info
.
add_instruction
(
sd
.
op
,
sd
.
op_args
);
if
(
not
sd
.
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
sd
.
raxes
}}),
ins
);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
sd
.
steps
.
begin
(),
sd
.
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
std
::
transform
(
sd
.
steps
.
begin
(),
sd
.
steps
.
end
(),
std
::
back_inserter
(
nsteps
),
[](
auto
s
)
{
return
std
::
abs
(
s
);
});
return
ins
=
info
.
add_instruction
(
make_op
(
"step"
,
{{
"axes"
,
sd
.
op
.
axes
},
{
"steps"
,
nsteps
}}),
ins
);
}
else
return
ins
;
}
std
::
vector
<
int64_t
>
steps
;
slice_desc
construct_slice_desc
(
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
slice_desc
sd
;
// slice can have up to 5 inputs, we first check the 5th one
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice.
// to decide whether MIGRAPHX can handle this slice.
...
@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice>
{
{
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
check_arg_empty
(
step_arg
,
"PARSE_SLICE: cannot handle variable steps for slice"
);
check_arg_empty
(
step_arg
,
"PARSE_SLICE: cannot handle variable steps for slice"
);
step_arg
.
visit
([
&
](
auto
s
)
{
steps
.
assign
(
s
.
begin
(),
s
.
end
());
});
step_arg
.
visit
([
&
](
auto
s
)
{
sd
.
steps
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
if
(
args
.
size
()
>=
4
)
if
(
args
.
size
()
>=
4
)
{
{
migraphx
::
argument
axes_arg
=
args
.
at
(
3
)
->
eval
();
sd
.
op
.
axes
=
sd
.
insert
(
args
.
at
(
3
));
check_arg_empty
(
axes_arg
,
"PARSE_SLICE: cannot handle variable axes for slice"
);
axes_arg
.
visit
([
&
](
auto
s
)
{
op
.
axes
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
else
if
(
contains
(
info
.
attributes
,
"axes"
))
else
if
(
contains
(
info
.
attributes
,
"axes"
))
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axes"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axes"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
axes
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
axes
));
});
}
}
if
(
args
.
size
()
>=
3
)
if
(
args
.
size
()
>=
3
)
{
{
migraphx
::
argument
end_arg
=
args
.
at
(
2
)
->
eval
();
sd
.
op
.
ends
=
sd
.
insert
(
args
.
at
(
2
));
check_arg_empty
(
end_arg
,
"PARSE_SLICE: cannot handle variable ends for slice"
);
end_arg
.
visit
([
&
](
auto
s
)
{
op
.
ends
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
else
if
(
contains
(
info
.
attributes
,
"ends"
))
else
if
(
contains
(
info
.
attributes
,
"ends"
))
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"ends"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"ends"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
ends
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
ends
));
});
}
}
if
(
args
.
size
()
>=
2
)
if
(
args
.
size
()
>=
2
)
{
{
migraphx
::
argument
start_arg
=
args
.
at
(
1
)
->
eval
();
sd
.
op
.
starts
=
sd
.
insert
(
args
.
at
(
1
));
check_arg_empty
(
start_arg
,
"PARSE_SLICE: cannot handle variable starts for slice"
);
start_arg
.
visit
([
&
](
auto
s
)
{
op
.
starts
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
else
if
(
contains
(
info
.
attributes
,
"starts"
))
else
if
(
contains
(
info
.
attributes
,
"starts"
))
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"starts"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"starts"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
starts
));
});
}
}
// data input argument
sd
.
always_insert
(
args
.
at
(
0
));
// If axes arg is not given, the default is all of them.
// If axes arg is not given, the default is all of them.
if
(
op
.
axes
.
empty
())
if
(
sd
.
op
.
axes
.
empty
()
and
sd
.
op_args
.
size
()
<
3
)
{
{
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
op
.
axes
=
axes
;
sd
.
op
.
axes
=
axes
;
}
}
std
::
vector
<
int64_t
>
raxes
;
if
(
not
sd
.
steps
.
empty
())
{
if
(
sd
.
op
.
starts
.
empty
()
or
sd
.
op
.
ends
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable starts and ends is not supported"
);
if
(
sd
.
op
.
axes
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable axes is not supported"
);
}
assert
(
steps
.
empty
()
or
steps
.
size
()
==
op
.
axes
.
size
());
assert
(
sd
.
steps
.
empty
()
or
sd
.
steps
.
size
()
==
sd
.
op
.
axes
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
steps
.
size
()))
for
(
auto
i
:
range
(
sd
.
steps
.
size
()))
{
{
if
(
steps
[
i
]
>=
0
)
if
(
sd
.
steps
[
i
]
>=
0
)
continue
;
continue
;
op
.
starts
[
i
]
+=
1
;
sd
.
op
.
starts
[
i
]
+=
1
;
if
(
op
.
starts
[
i
]
==
0
)
if
(
sd
.
op
.
starts
[
i
]
==
0
)
op
.
starts
[
i
]
=
INT_MAX
;
sd
.
op
.
starts
[
i
]
=
INT_MAX
;
op
.
ends
[
i
]
+=
1
;
sd
.
op
.
ends
[
i
]
+=
1
;
raxes
.
push_back
(
op
.
axes
[
i
]);
sd
.
raxes
.
push_back
(
sd
.
op
.
axes
[
i
]);
std
::
swap
(
op
.
starts
[
i
],
op
.
ends
[
i
]);
std
::
swap
(
sd
.
op
.
starts
[
i
],
sd
.
op
.
ends
[
i
]);
}
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
if
(
not
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
}
}
// If any steps are other than default 1, add a "steps" op
return
sd
;
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
std
::
transform
(
steps
.
begin
(),
steps
.
end
(),
std
::
back_inserter
(
nsteps
),
[](
auto
s
)
{
return
std
::
abs
(
s
);
});
return
ins
=
info
.
add_instruction
(
make_op
(
"step"
,
{{
"axes"
,
op
.
axes
},
{
"steps"
,
nsteps
}}),
ins
);
}
else
return
ins
;
}
}
};
};
...
...
src/permutation.cpp
View file @
f12064ee
...
@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
...
@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
return
it
->
first
;
return
it
->
first
;
}
}
std
::
vector
<
shape
>
normalize_permutation
(
const
std
::
vector
<
shape
>&
shapes
)
{
auto
result
=
shapes
;
auto
perm
=
find_permutation
(
shapes
);
std
::
transform
(
result
.
begin
(),
result
.
end
(),
result
.
begin
(),
[
&
](
auto
s
)
{
return
reorder_shape
(
s
,
perm
);
});
return
result
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/program.cpp
View file @
f12064ee
...
@@ -40,8 +40,10 @@
...
@@ -40,8 +40,10 @@
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
#include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp>
#include <migraphx/supported_segments.hpp>
#include <iostream>
#include <iostream>
#include <queue>
#include <queue>
#include <queue>
#include <sstream>
#include <sstream>
#include <algorithm>
#include <algorithm>
#include <set>
#include <set>
...
@@ -229,7 +231,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
...
@@ -229,7 +231,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
// Gather all the target roots
// Gather all the target roots
std
::
unordered_multimap
<
std
::
size_t
,
module_ref
>
roots
;
std
::
unordered_multimap
<
std
::
size_t
,
module_ref
>
roots
;
auto
mods
=
this
->
get_modules
();
auto
mods
=
this
->
get_modules
();
for
(
auto
*
mod
:
mods
)
for
(
const
auto
*
mod
:
mods
)
{
{
for
(
const
auto
&
ins
:
*
mod
)
for
(
const
auto
&
ins
:
*
mod
)
{
{
...
@@ -554,7 +556,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
...
@@ -554,7 +556,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
ins_out
[
x
]
=
ss
.
str
();
ins_out
[
x
]
=
ss
.
str
();
});
});
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
instruction_ref
ins
,
auto
f
)
{
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
instruction_ref
ins
,
auto
f
)
{
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
const
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
ctx
.
finish
();
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
timer
t
{};
timer
t
{};
...
@@ -734,7 +736,7 @@ static void mod_from_val(module_ref mod,
...
@@ -734,7 +736,7 @@ static void mod_from_val(module_ref mod,
std
::
back_inserter
(
module_inputs
),
std
::
back_inserter
(
module_inputs
),
[
&
](
const
value
&
i
)
{
return
map_mods
.
at
(
i
.
to
<
std
::
string
>
());
});
[
&
](
const
value
&
i
)
{
return
map_mods
.
at
(
i
.
to
<
std
::
string
>
());
});
for
(
auto
&
smod
:
module_inputs
)
for
(
const
auto
&
smod
:
module_inputs
)
{
{
mod_from_val
(
smod
,
v
,
instructions
,
map_mods
);
mod_from_val
(
smod
,
v
,
instructions
,
map_mods
);
}
}
...
@@ -1192,7 +1194,7 @@ void program::remove_unused_modules()
...
@@ -1192,7 +1194,7 @@ void program::remove_unused_modules()
std
::
vector
<
module
*>
unused
;
std
::
vector
<
module
*>
unused
;
generic_get_unused_modules
(
generic_get_unused_modules
(
impl
->
modules
,
generic_get_modules
(
this
->
get_main_module
()),
std
::
back_inserter
(
unused
));
impl
->
modules
,
generic_get_modules
(
this
->
get_main_module
()),
std
::
back_inserter
(
unused
));
for
(
auto
*
m
:
unused
)
for
(
const
auto
*
m
:
unused
)
this
->
remove_module
(
m
->
name
());
this
->
remove_module
(
m
->
name
());
}
}
...
@@ -1200,7 +1202,7 @@ program& program::sort()
...
@@ -1200,7 +1202,7 @@ program& program::sort()
{
{
std
::
queue
<
migraphx
::
module_ref
>
mqueue
;
std
::
queue
<
migraphx
::
module_ref
>
mqueue
;
mqueue
.
push
(
get_main_module
());
mqueue
.
push
(
get_main_module
());
while
(
!
mqueue
.
empty
())
while
(
not
mqueue
.
empty
())
{
{
module_ref
current_mod
=
mqueue
.
front
();
module_ref
current_mod
=
mqueue
.
front
();
current_mod
->
sort
();
current_mod
->
sort
();
...
...
src/py/CMakeLists.txt
View file @
f12064ee
...
@@ -23,14 +23,25 @@
...
@@ -23,14 +23,25 @@
#####################################################################################
#####################################################################################
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
add_library
(
migraphx_py py_loader.cpp
)
migraphx_generate_export_header
(
migraphx_py
)
target_include_directories
(
migraphx_py PRIVATE include
)
target_link_libraries
(
migraphx_py PUBLIC migraphx
)
rocm_install_targets
(
TARGETS migraphx_py INCLUDE include
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
include
(
PythonModules
)
include
(
PythonModules
)
add_custom_target
(
migraphx_py
)
foreach
(
PYTHON_VERSION
${
PYTHON_VERSIONS
}
)
foreach
(
PYTHON_VERSION
${
PYTHON_VERSIONS
}
)
py_add_module
(
migraphx_py_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
py_add_module
(
migraphx_pybind_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
target_link_libraries
(
migraphx_pybind_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
rocm_install_targets
(
TARGETS migraphx_pybind_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_pybind_
${
PYTHON_VERSION
}
)
add_library
(
migraphx_py_
${
PYTHON_VERSION
}
py.cpp
)
target_include_directories
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE include
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PUBLIC migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE pybind11::pybind11 python
${
PYTHON_VERSION
}
::runtime
)
rocm_install_targets
(
TARGETS migraphx_py_
${
PYTHON_VERSION
}
)
rocm_install_targets
(
TARGETS migraphx_py_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_py_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_py_
${
PYTHON_VERSION
}
)
endforeach
()
endforeach
()
...
...
src/py/include/migraphx/py.hpp
0 → 100644
View file @
f12064ee
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
src/py/py.cpp
0 → 100644
View file @
f12064ee
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/file_buffer.hpp>
#include <pybind11/embed.h>
namespace
py
=
pybind11
;
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
// extern "C" is used to disable name mangling, but the function will still be called from C++
extern
"C"
program
migraphx_load_py
(
const
std
::
string
&
filename
);
#ifdef __clang__
#pragma clang diagnostic pop
#endif
const
std
::
string
&
python_path
()
{
static
const
auto
path
=
dynamic_loader
::
path
(
&
migraphx_load_py
).
parent_path
().
string
();
return
path
;
}
static
py
::
dict
run_file
(
const
std
::
string
&
file
)
{
py
::
object
scope
=
py
::
module_
::
import
(
"__main__"
).
attr
(
"__dict__"
);
std
::
string
buffer
;
buffer
.
append
(
"import sys
\n
"
);
buffer
.
append
(
"sys.path.insert(0, '"
+
python_path
()
+
"')
\n
"
);
buffer
.
append
(
"import migraphx
\n
"
);
buffer
.
append
(
read_string
(
file
));
py
::
exec
(
buffer
,
scope
);
return
scope
.
cast
<
py
::
dict
>
();
}
extern
"C"
program
migraphx_load_py
(
const
std
::
string
&
filename
)
{
py
::
scoped_interpreter
guard
{};
py
::
dict
vars
=
run_file
(
filename
);
auto
it
=
std
::
find_if
(
vars
.
begin
(),
vars
.
end
(),
[](
const
auto
&
p
)
{
return
py
::
isinstance
<
migraphx
::
program
>
(
p
.
second
);
});
if
(
it
==
vars
.
end
())
MIGRAPHX_THROW
(
"No program variable found"
);
return
it
->
second
.
cast
<
migraphx
::
program
>
();
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/py/py_loader.cpp
0 → 100644
View file @
f12064ee
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/py.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
static
std
::
vector
<
fs
::
path
>
find_available_python_versions
()
{
std
::
vector
<
fs
::
path
>
result
;
auto
path
=
dynamic_loader
::
path
(
&
load_py
).
parent_path
();
for
(
const
auto
&
entry
:
fs
::
directory_iterator
{
path
})
{
auto
p
=
entry
.
path
();
if
(
not
fs
::
is_regular_file
(
p
))
continue
;
if
(
not
contains
(
p
.
stem
().
string
(),
"migraphx_py_"
))
continue
;
result
.
push_back
(
p
);
}
std
::
sort
(
result
.
begin
(),
result
.
end
(),
std
::
greater
<>
{});
return
result
;
}
static
dynamic_loader
load_py_lib
()
{
auto
libs
=
find_available_python_versions
();
for
(
const
auto
&
lib
:
libs
)
{
auto
result
=
dynamic_loader
::
try_load
(
lib
);
if
(
result
.
has_value
())
return
*
result
;
}
MIGRAPHX_THROW
(
"Cant find a viable version of python"
);
}
static
dynamic_loader
py_lib
()
{
static
dynamic_loader
lib
=
load_py_lib
();
return
lib
;
}
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
)
{
static
auto
f
=
py_lib
().
get_function
<
program
(
const
std
::
string
&
)
>
(
"migraphx_load_py"
);
return
f
(
filename
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/rewrite_quantization.cpp
View file @
f12064ee
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant
=
qt
.
max
();
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
min_quant
=
qt
.
min
();
});
});
auto
s
=
add_zero_point
->
get_shape
();
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
m
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
}
...
...
src/simplify_algebra.cpp
View file @
f12064ee
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
}
struct
find_conv_dot_horiz_fusion
struct
find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
return
false
;
// Check that non-axes match
// Check that non-axes match
int
axis
=
1
;
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
{
axis
=
x
.
size
()
-
1
;
axis
=
x
.
size
()
-
1
;
}
}
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
return
;
auto
&&
name
=
(
*
start
)
->
name
();
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
return
;
auto
op
=
(
*
start
)
->
get_operator
();
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
int
group
=
1
;
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
axis
=
1
;
int
concat_axis
=
0
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
concat_axis
=
axis
;
...
...
src/sqlite.cpp
View file @
f12064ee
...
@@ -48,6 +48,7 @@ struct sqlite_impl
...
@@ -48,6 +48,7 @@ struct sqlite_impl
template
<
class
F
>
template
<
class
F
>
void
exec
(
const
char
*
sql
,
F
f
)
void
exec
(
const
char
*
sql
,
F
f
)
{
{
// cppcheck-suppress constParameterPointer
auto
callback
=
[](
void
*
obj
,
auto
...
xs
)
->
int
{
auto
callback
=
[](
void
*
obj
,
auto
...
xs
)
->
int
{
try
try
{
{
...
...
src/targets/cpu/gemm.cpp
View file @
f12064ee
...
@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
...
@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX
(
ARG_BIAS
)};
MIGRAPHX_DNNL_PREFIX
(
ARG_BIAS
)};
}
}
void
required
(
const
check_shapes
&
cs
)
const
{
cs
.
not_broadcasted
();
}
template
<
class
T
>
void
required
(
const
check_shapes
<
T
>&
cs
)
const
{
cs
.
not_broadcasted
();
}
dnnl
::
matmul
::
desc
get_desc
(
const
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>&
m
)
const
dnnl
::
matmul
::
desc
get_desc
(
const
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>&
m
)
const
{
{
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment