Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
23cb7917
Unverified
Commit
23cb7917
authored
Aug 16, 2023
by
Brian Pickrell
Committed by
GitHub
Aug 16, 2023
Browse files
Merge branch 'develop' into blas_tuning
parents
b5fcc0bc
ea32ca70
Changes
458
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
999 additions
and
276 deletions
+999
-276
src/onnx/parse_randomuniform_ops.cpp
src/onnx/parse_randomuniform_ops.cpp
+2
-3
src/onnx/parse_shape.cpp
src/onnx/parse_shape.cpp
+53
-9
src/onnx/parse_where.cpp
src/onnx/parse_where.cpp
+1
-0
src/pass_manager.cpp
src/pass_manager.cpp
+29
-15
src/permutation.cpp
src/permutation.cpp
+10
-0
src/program.cpp
src/program.cpp
+283
-128
src/promote_literals.cpp
src/promote_literals.cpp
+1
-1
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+14
-3
src/py/include/migraphx/py.hpp
src/py/include/migraphx/py.hpp
+38
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+58
-9
src/py/py.cpp
src/py/py.cpp
+76
-0
src/py/py_loader.cpp
src/py/py_loader.cpp
+74
-0
src/quantization.cpp
src/quantization.cpp
+5
-11
src/quantize_fp16.cpp
src/quantize_fp16.cpp
+11
-10
src/replace_allocate.cpp
src/replace_allocate.cpp
+5
-3
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+17
-14
src/shape.cpp
src/shape.cpp
+45
-18
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+254
-20
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+13
-28
src/split_single_dyn_dim.cpp
src/split_single_dyn_dim.cpp
+10
-4
No files found.
src/onnx/parse_randomuniform_ops.cpp
View file @
23cb7917
...
...
@@ -35,8 +35,7 @@ namespace onnx {
struct
parse_randomuniform_ops
:
op_parser
<
parse_randomuniform_ops
>
{
const
std
::
set
<
shape
::
type_t
>
valid_types
=
{
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
};
std
::
set
<
shape
::
type_t
>
valid_types
=
{
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
};
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"RandomUniform"
},
{
"RandomUniformLike"
}};
}
...
...
@@ -97,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if
(
contains
(
info
.
attributes
,
"seed"
))
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
uniform_real_distribution
<>
d
(
high
,
low
);
std
::
uniform_real_distribution
<>
d
(
low
,
high
);
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
...
...
src/onnx/parse_shape.cpp
View file @
23cb7917
...
...
@@ -30,8 +30,11 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
/**
* If static shape input, creates a literal in migraphx.
* If dynamic shape input, creates a dimensions_of operator in migraphx (runtime evaluation of
* shape).
*/
struct
parse_shape
:
op_parser
<
parse_shape
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Shape"
}};
}
...
...
@@ -43,14 +46,55 @@ struct parse_shape : op_parser<parse_shape>
{
if
(
args
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"Shape: operator should have 1 operand"
);
std
::
vector
<
std
::
size_t
>
arg_shape
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
vec_shape
(
arg_shape
.
size
());
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
arg_shape
.
size
()});
std
::
transform
(
arg_shape
.
begin
(),
arg_shape
.
end
(),
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
auto
input_shape
=
args
[
0
]
->
get_shape
();
int
input_ndim
=
input_shape
.
ndim
();
std
::
size_t
start
=
0
;
std
::
size_t
end
=
input_ndim
;
// Normalizing the start and end is handled here because of how the static shape version
// works. Clamping to [-r, r], where r is ndim of input and then making positive.
auto
normalize_ind
=
[
&
](
int64_t
ind
)
{
if
(
ind
<
(
-
1
*
input_ndim
))
{
ind
=
-
1
*
input_ndim
;
}
if
(
ind
>
input_ndim
)
{
ind
=
input_ndim
;
}
return
(
ind
>=
0
)
?
ind
:
input_ndim
+
ind
;
};
if
(
contains
(
info
.
attributes
,
"end"
))
{
end
=
normalize_ind
(
info
.
attributes
.
at
(
"end"
).
i
());
}
if
(
contains
(
info
.
attributes
,
"start"
))
{
start
=
normalize_ind
(
info
.
attributes
.
at
(
"start"
).
i
());
}
if
(
end
<=
start
)
{
MIGRAPHX_THROW
(
"PARSE_SHAPE: ending axis <= starting axis, end: "
+
std
::
to_string
(
end
)
+
" start: "
+
std
::
to_string
(
start
));
}
if
(
input_shape
.
dynamic
())
{
return
info
.
add_instruction
(
make_op
(
"dimensions_of"
,
{{
"start"
,
start
},
{
"end"
,
end
}}),
args
[
0
]);
}
else
{
std
::
size_t
output_ndim
=
end
-
start
;
std
::
vector
<
int64_t
>
vec_shape
(
output_ndim
);
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
output_ndim
});
std
::
vector
<
std
::
size_t
>
input_lens
=
input_shape
.
lens
();
std
::
transform
(
input_lens
.
begin
()
+
start
,
input_lens
.
begin
()
+
end
,
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
return
info
.
add_literal
(
migraphx
::
literal
{
s
,
vec_shape
});
}
}
};
}
// namespace onnx
...
...
src/onnx/parse_where.cpp
View file @
23cb7917
...
...
@@ -56,6 +56,7 @@ struct parse_where : op_parser<parse_where>
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
0
]
=
...
...
src/pass_manager.cpp
View file @
23cb7917
...
...
@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct
module_pm
:
module_pass_manager
{
module
*
mod
=
nullptr
;
module
*
root_mod
=
nullptr
;
tracer
*
t
=
nullptr
;
module
*
common_parent
=
nullptr
;
program
*
prog
=
nullptr
;
module_pm
(
module
*
pmod
=
nullptr
,
tracer
*
pt
=
nullptr
)
:
mod
(
pmod
),
t
(
pt
)
{}
module_pm
(
module
*
pmod
=
nullptr
,
module
*
rmod
=
nullptr
,
tracer
*
pt
=
nullptr
)
:
mod
(
pmod
),
root_mod
(
rmod
),
t
(
pt
)
{
}
template
<
class
...
Ts
>
void
trace
(
Ts
&&
...
xs
)
const
{
...
...
@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
virtual
module
*
get_root_module
()
override
{
if
(
root_mod
!=
nullptr
)
return
root_mod
;
assert
(
prog
);
return
prog
->
get_main_module
();
}
...
...
@@ -123,33 +131,24 @@ struct module_pm : module_pass_manager
module
&
get_module
(
module_pass_manager
&
mpm
)
{
return
mpm
.
get_module
();
}
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
for
(
const
auto
&
p
:
passes
)
{
module_pm
{
&
mod
,
&
trace
}.
run_pass
(
p
);
}
}
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
void
run_passes
(
program
&
prog
,
module_ref
root_mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
std
::
unordered_set
<
module_ref
>
visited
;
for
(
const
auto
&
p
:
passes
)
{
auto
mods
=
prog
.
get_modules
();
auto
tree
=
prog
.
get_module_tree
();
std
::
vector
<
module_ref
>
sub_mods
=
root_mod
->
get_sub_modules
();
sub_mods
.
insert
(
sub_mods
.
begin
(),
root_mod
);
visited
.
clear
();
for
(
const
auto
&
mod
:
reverse
(
mods
))
for
(
const
auto
&
mod
:
reverse
(
sub_
mods
))
{
if
(
mod
->
bypass
())
continue
;
if
(
not
visited
.
insert
(
mod
).
second
)
continue
;
module_pm
mpm
{
mod
,
&
trace
};
module_pm
mpm
{
mod
,
root_mod
,
&
trace
};
mpm
.
prog
=
&
prog
;
auto
parents
=
range
(
tree
.
equal_range
(
mod
));
auto
nparents
=
distance
(
parents
);
...
...
@@ -167,5 +166,20 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
}
}
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
for
(
const
auto
&
p
:
passes
)
{
module_pm
{
&
mod
,
&
mod
,
&
trace
}.
run_pass
(
p
);
}
}
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
run_passes
(
prog
,
prog
.
get_main_module
(),
passes
,
trace
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/permutation.cpp
View file @
23cb7917
...
...
@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
return
it
->
first
;
}
std
::
vector
<
shape
>
normalize_permutation
(
const
std
::
vector
<
shape
>&
shapes
)
{
auto
result
=
shapes
;
auto
perm
=
find_permutation
(
shapes
);
std
::
transform
(
result
.
begin
(),
result
.
end
(),
result
.
begin
(),
[
&
](
auto
s
)
{
return
reorder_shape
(
s
,
perm
);
});
return
result
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/program.cpp
View file @
23cb7917
...
...
@@ -21,6 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/version.h>
#include <migraphx/compile_options.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
...
...
@@ -38,12 +40,14 @@
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp>
#include <iostream>
#include <queue>
#include <sstream>
#include <algorithm>
#include <set>
#include <unordered_map>
#include <utility>
#include <unordered_set>
#include <map>
#include <cassert>
...
...
@@ -53,12 +57,23 @@ inline namespace MIGRAPHX_INLINE_NS {
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
struct
mark_instruction_target
{
std
::
size_t
target_id
=
0
;
std
::
string
name
()
const
{
return
"mark_instruction_target"
;
}
void
apply
(
module
&
m
)
const
{
for
(
auto
&
ins
:
m
)
ins
.
set_target_id
(
target_id
);
}
};
struct
program_impl
{
// A map is used to keep references to modules of the program
std
::
unordered_map
<
std
::
string
,
module
>
modules
;
context
c
tx
;
std
::
string
target_name
;
std
::
vector
<
context
>
c
ontexts
;
std
::
vector
<
target
>
targets
;
};
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
this
->
create_module
(
"main"
);
}
...
...
@@ -82,14 +97,8 @@ void program::assign(const program& p)
{
impl
=
std
::
make_unique
<
program_impl
>
();
}
else
if
(
not
impl
->
modules
.
empty
())
{
impl
->
modules
.
clear
();
}
impl
->
ctx
=
p
.
impl
->
ctx
;
impl
->
target_name
=
p
.
impl
->
target_name
;
impl
->
modules
=
p
.
impl
->
modules
;
*
impl
=
*
p
.
impl
;
// build a map from old ins to new ins
// Build a map from old module to new module
...
...
@@ -152,7 +161,11 @@ std::vector<shape> program::get_output_shapes() const
return
mm
->
get_output_shapes
();
}
context
&
program
::
get_context
()
const
{
return
impl
->
ctx
;
}
context
&
program
::
get_context
()
const
{
assert
(
impl
->
contexts
.
size
()
==
1
);
return
impl
->
contexts
.
front
();
}
instruction_ref
program
::
validate
()
const
{
...
...
@@ -203,20 +216,106 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
return
p
;
}
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
target_name
.
empty
();
}
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
contexts
.
empty
();
}
void
program
::
compile
(
const
std
::
vector
<
target
>&
targets
,
std
::
vector
<
compile_options
>
compile_opts
)
{
// Gather all the target roots
std
::
unordered_multimap
<
std
::
size_t
,
module_ref
>
roots
;
auto
mods
=
this
->
get_modules
();
for
(
const
auto
*
mod
:
mods
)
{
for
(
const
auto
&
ins
:
*
mod
)
{
if
(
ins
.
name
()
!=
"run_on_target"
)
continue
;
auto
v
=
ins
.
get_operator
().
to_value
();
module_ref
root
=
ins
.
module_inputs
().
front
();
std
::
size_t
root_target_id
=
v
.
at
(
"target_id"
).
to
<
std
::
size_t
>
();
assert
(
root_target_id
<
targets
.
size
());
roots
.
insert
({
root_target_id
,
root
});
}
}
auto
trace
=
tracer
{};
// TODO: Add tracer based on compile options
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
trace
=
tracer
{
std
::
cout
};
trace
(
*
this
);
trace
();
// It is assumed that all instructions outside of any root module would run on "ref" target
// Ref target may or may not be passed as one of the target for the "compile()".
// If it is not passed, Create one and add context of it into the map.
auto
target_idx
=
[
&
](
const
std
::
string
&
t_name
)
{
return
static_cast
<
std
::
size_t
>
(
std
::
find_if
(
targets
.
begin
(),
targets
.
end
(),
[
&
](
const
auto
&
t
)
{
return
t
.
name
()
==
t_name
;
})
-
targets
.
begin
());
};
std
::
size_t
ref_target_id
=
target_idx
(
"ref"
);
if
(
ref_target_id
==
targets
.
size
())
{
this
->
impl
->
contexts
.
resize
(
targets
.
size
()
+
1
);
this
->
impl
->
contexts
[
ref_target_id
]
=
migraphx
::
make_target
(
"ref"
).
get_context
();
// users could pass lessers compile_ops than targets, in that case use default compile_opts
compile_opts
.
resize
(
targets
.
size
()
+
1
,
migraphx
::
compile_options
{});
}
else
{
this
->
impl
->
contexts
.
resize
(
targets
.
size
());
compile_opts
.
resize
(
targets
.
size
(),
migraphx
::
compile_options
{});
}
// mark all the instruction as ref target first, later change target_id based on root-target
run_passes
(
*
this
,
{
mark_instruction_target
{
ref_target_id
}});
// Run passes on each root target
for
(
const
auto
i
:
range
(
targets
.
size
()))
{
const
auto
&
root_target
=
targets
.
at
(
i
);
auto
root_target_id
=
i
;
auto
root_modules_range
=
roots
.
equal_range
(
root_target_id
);
this
->
impl
->
contexts
[
root_target_id
]
=
root_target
.
get_context
();
for
(
const
auto
&
[
id
,
current_mod
]
:
range
(
root_modules_range
))
{
auto
passes
=
root_target
.
get_passes
(
this
->
impl
->
contexts
[
root_target_id
],
compile_opts
[
root_target_id
]);
passes
.
push_back
(
mark_instruction_target
{
static_cast
<
size_t
>
(
root_target_id
)});
run_passes
(
*
this
,
current_mod
,
passes
,
trace
);
auto
invalid
=
current_mod
->
validate
();
if
(
invalid
!=
current_mod
->
end
())
{
MIGRAPHX_THROW
(
"Invalid module "
+
current_mod
->
name
()
+
" from compilation at instruction "
+
std
::
to_string
(
std
::
distance
(
current_mod
->
begin
(),
invalid
)));
}
auto
dangling
=
current_mod
->
find_dangling_reference
();
if
(
dangling
!=
current_mod
->
end
())
{
auto
index
=
std
::
distance
(
current_mod
->
begin
(),
dangling
);
MIGRAPHX_THROW
(
"Dangling reference in module "
+
current_mod
->
name
()
+
" from instruction "
+
std
::
to_string
(
index
));
}
}
}
this
->
finalize
();
}
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
{
// todo: combine with multi-target compile method
assert
(
not
this
->
is_compiled
());
this
->
impl
->
target
_name
=
t
.
name
()
;
this
->
impl
->
c
tx
=
t
.
get_context
();
this
->
impl
->
target
s
=
{
t
}
;
this
->
impl
->
c
ontexts
=
{
t
.
get_context
()
}
;
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
(
*
this
);
options
.
trace
();
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
c
tx
,
options
);
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
c
ontexts
.
front
()
,
options
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
auto
mods
=
this
->
get_modules
();
// Validate and finalize
...
...
@@ -235,14 +334,14 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW
(
"Dangling reference in module "
+
mod
->
name
()
+
" from instruction "
+
std
::
to_string
(
index
));
}
mod
->
finalize
(
this
->
impl
->
c
tx
);
mod
->
finalize
(
this
->
impl
->
c
ontexts
);
}
}
void
program
::
finalize
()
{
auto
*
mm
=
this
->
get_main_module
();
mm
->
finalize
(
this
->
impl
->
c
tx
);
mm
->
finalize
(
this
->
impl
->
c
ontexts
);
}
template
<
class
T
>
...
...
@@ -259,6 +358,31 @@ std::string classify(T x)
}
}
void
print_statistics
(
std
::
ostream
&
os
,
const
argument
&
a
)
{
a
.
visit
(
[
&
](
auto
t
)
{
os
<<
"Min value: "
<<
*
std
::
min_element
(
t
.
begin
(),
t
.
end
())
<<
", "
;
os
<<
"Max value: "
<<
*
std
::
max_element
(
t
.
begin
(),
t
.
end
())
<<
", "
;
double
num_elements
=
t
.
size
();
auto
mean
=
std
::
accumulate
(
t
.
begin
(),
t
.
end
(),
0.0
)
/
num_elements
;
auto
stddev
=
std
::
sqrt
(
std
::
accumulate
(
t
.
begin
(),
t
.
end
(),
0.0
,
[
&
](
auto
r
,
auto
v
)
{
return
r
+
std
::
pow
((
v
-
mean
),
2.0
);
})
/
num_elements
);
os
<<
"Mean: "
<<
mean
<<
", "
;
os
<<
"StdDev: "
<<
stddev
<<
"
\n
"
;
},
[
&
](
const
auto
&
xs
)
{
for
(
const
auto
&
x
:
xs
)
{
print_statistics
(
os
,
x
);
}
});
}
std
::
unordered_set
<
std
::
string
>
classify_argument
(
const
argument
&
a
)
{
std
::
unordered_set
<
std
::
string
>
result
;
...
...
@@ -304,16 +428,15 @@ void preview_argument(std::ostream& os, const argument& a)
template
<
class
F
>
std
::
vector
<
argument
>
generic_eval
(
const
module
*
mod
,
context
&
ctx
,
std
::
vector
<
context
>
&
ctx
,
std
::
unordered_map
<
std
::
string
,
argument
>
params
,
std
::
unordered_map
<
instruction_ref
,
argument
>
results
,
F
make_
trace
)
F
trace
)
{
assert
(
mod
->
validate
()
==
mod
->
end
());
results
.
reserve
(
mod
->
size
()
*
2
);
std
::
vector
<
argument
>
values
;
values
.
reserve
(
16
);
auto
trace
=
make_trace
(
mod
);
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
assert
(
results
.
find
(
ins
)
==
results
.
end
());
...
...
@@ -366,17 +489,21 @@ std::vector<argument> generic_eval(const module* mod,
assert
(
results
.
find
(
i
)
!=
results
.
end
());
return
results
[
i
];
});
const
auto
&
mod_args
=
ins
->
module_inputs
();
auto
module_eval
=
[
&
](
module_ref
smod
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
{
auto
ssctx
=
ctx
;
return
generic_eval
(
smod
,
ssctx
,
inputs
,
results
,
make_trace
);
return
generic_eval
(
smod
,
ctx
,
inputs
,
results
,
trace
);
};
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
return
ins
->
normalized_operator
().
compute
(
ctx
,
ins
->
get_shape
(),
values
,
mod_args
,
module_eval
);
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
auto
op
=
ins
->
normalized_operator
();
if
(
op
.
is_context_free
())
return
op
.
compute
(
ins
->
get_shape
(),
values
,
mod_args
,
module_eval
);
if
(
ins
->
get_target_id
()
>=
ctx
.
size
())
MIGRAPHX_THROW
(
"No context available for "
+
op
.
name
());
return
op
.
compute
(
ctx
[
ins
->
get_target_id
()],
ins
->
get_shape
(),
values
,
mod_args
,
module_eval
);
}));
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
...
...
@@ -390,44 +517,25 @@ std::vector<argument> generic_eval(const module* mod,
template
<
class
F
>
std
::
vector
<
argument
>
generic_eval
(
const
program
&
p
,
context
&
ctx
,
std
::
vector
<
context
>
&
ctx
,
std
::
unordered_map
<
std
::
string
,
argument
>
params
,
F
make_
trace
)
F
trace
)
{
const
module
*
mm
=
p
.
get_main_module
();
return
generic_eval
(
mm
,
ctx
,
params
,
{},
make_
trace
);
return
generic_eval
(
mm
,
ctx
,
params
,
{},
trace
);
}
std
::
vector
<
argument
>
program
::
eval
(
parameter_map
params
,
execution_environment
exec_env
)
const
{
auto
&
ctx
=
this
->
impl
->
ctx
;
#ifndef NDEBUG
auto
with_check_context
=
[
&
](
auto
f
)
{
return
[
=
,
&
ctx
](
auto
&&
)
{
auto
sctx
=
std
::
make_shared
<
context
>
(
ctx
);
auto
check_context
=
[
=
,
&
ctx
](
auto
g
)
{
assert
(
is_shared
(
ctx
,
*
sctx
));
auto
x
=
g
();
*
sctx
=
ctx
;
return
x
;
};
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
xs
...,
check_context
);
};
};
};
#else
auto
with_check_context
=
[](
auto
f
)
{
return
[
=
](
auto
&&
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
xs
...,
[](
auto
g
)
{
return
g
();
});
};
};
};
#endif
auto
&
contexts
=
this
->
impl
->
contexts
;
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_EVAL
{});
std
::
vector
<
argument
>
ret
;
if
(
exec_env
.
async
)
{
ctx
.
wait_for
(
exec_env
.
queue
);
assert
(
contexts
.
size
()
==
1
);
contexts
.
front
().
wait_for
(
exec_env
.
queue
);
}
if
(
trace_level
>
0
)
...
...
@@ -439,32 +547,42 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
instruction
::
print
(
ss
,
x
,
ins_names
);
ins_out
[
x
]
=
ss
.
str
();
});
ret
=
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
with_check_context
([
&
](
auto
&
ins
,
auto
f
,
auto
&&
check_context
)
{
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
instruction_ref
ins
,
auto
f
)
{
const
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
timer
t
{};
auto
result
=
check_context
(
f
);
auto
result
=
f
(
);
double
t1
=
t
.
record
<
milliseconds
>
();
ctx
.
finish
();
double
t2
=
t
.
record
<
milliseconds
>
();
std
::
cout
<<
"Time: "
<<
t1
<<
"ms, "
<<
t2
<<
"ms"
<<
std
::
endl
;
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
ins
->
name
()
!=
"load"
and
not
result
.
empty
())
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
ins
->
name
()
!=
"load"
and
not
result
.
empty
())
{
migraphx
::
argument
buffer
;
try
{
const
target
&
tgt
=
this
->
impl
->
targets
.
at
(
ins
->
get_target_id
());
buffer
=
tgt
.
copy_from
(
result
);
}
catch
(
const
migraphx
::
exception
&
)
{
target
tgt
=
make_target
(
this
->
impl
->
target_name
);
auto
buffer
=
tgt
.
copy_from
(
result
);
// instruction was run on host then no need to copy buffer from target
buffer
=
result
;
}
catch
(...)
{
MIGRAPHX_THROW
(
"MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.
\n
"
);
}
if
(
trace_level
==
2
)
{
std
::
cout
<<
"Output has "
<<
to_string_range
(
classify_argument
(
buffer
))
std
::
cout
<<
"Output has "
<<
to_string_range
(
classify_argument
(
buffer
))
<<
std
::
endl
;
std
::
cout
<<
"Output: "
;
preview_argument
(
std
::
cout
,
buffer
);
std
::
cout
<<
std
::
endl
;
print_statistics
(
std
::
cout
,
buffer
);
}
else
{
...
...
@@ -472,36 +590,49 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
}
}
return
result
;
})
)
;
});
}
else
{
ret
=
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
with_check_context
([
&
](
auto
&
,
auto
f
,
auto
&&
check_context
)
{
return
check_context
(
f
);
}));
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
auto
&&
,
auto
f
)
{
return
f
();
});
}
if
(
exec_env
.
async
)
{
ctx
.
finish_on
(
exec_env
.
queue
);
assert
(
contexts
.
size
()
==
1
);
contexts
.
front
().
finish_on
(
exec_env
.
queue
);
}
return
ret
;
}
const
int
program_file_version
=
5
;
void
program
::
finish
()
const
{
for
(
const
auto
&
ctx
:
this
->
impl
->
contexts
)
ctx
.
finish
();
}
std
::
string
get_migraphx_version
()
{
std
::
stringstream
ss
;
ss
<<
std
::
to_string
(
MIGRAPHX_VERSION_MAJOR
)
<<
"."
<<
std
::
to_string
(
MIGRAPHX_VERSION_MINOR
)
<<
"."
<<
std
::
to_string
(
MIGRAPHX_VERSION_PATCH
);
return
ss
.
str
();
}
/*
program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file.
*/
const
int
program_file_version
=
6
;
value
program
::
to_value
()
const
{
value
result
;
result
[
"version"
]
=
program_file_version
;
result
[
"target"
]
=
this
->
impl
->
target_name
;
if
(
not
this
->
impl
->
target_name
.
empty
())
result
[
"context"
]
=
this
->
impl
->
ctx
.
to_value
();
result
[
"migraphx_version"
]
=
get_migraphx_version
();
result
[
"targets"
]
=
migraphx
::
to_value
(
this
->
impl
->
targets
);
result
[
"contexts"
]
=
migraphx
::
to_value
(
this
->
impl
->
contexts
);
value
module_vals
=
value
::
object
{};
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
for
(
auto
&
mod
:
this
->
get_modules
())
...
...
@@ -597,7 +728,7 @@ static void mod_from_val(module_ref mod,
std
::
back_inserter
(
module_inputs
),
[
&
](
const
value
&
i
)
{
return
map_mods
.
at
(
i
.
to
<
std
::
string
>
());
});
for
(
auto
&
smod
:
module_inputs
)
for
(
const
auto
&
smod
:
module_inputs
)
{
mod_from_val
(
smod
,
v
,
instructions
,
map_mods
);
}
...
...
@@ -626,15 +757,27 @@ void program::from_value(const value& v)
auto
version
=
v
.
at
(
"version"
).
to
<
int
>
();
if
(
version
!=
program_file_version
)
{
MIGRAPHX_THROW
(
"Warning: Program version mismatch"
);
MIGRAPHX_THROW
(
"Error: Program version mismatch. MXR file was created using program file version: "
+
std
::
to_string
(
version
)
+
", while installed MIGraphX is using program file version: "
+
std
::
to_string
(
program_file_version
)
+
", Try regenerating MXR file using installed MIGraphX and running again."
);
}
this
->
impl
->
target_name
=
v
.
at
(
"target
"
).
to
<
std
::
string
>
();
if
(
not
this
->
impl
->
target_name
.
empty
())
auto
migx_version
=
v
.
at
(
"migraphx_version
"
).
to
<
std
::
string
>
();
if
(
migx_version
!=
get_migraphx_version
())
{
target
t
=
make_target
(
this
->
impl
->
target_name
);
this
->
impl
->
ctx
=
t
.
get_context
();
this
->
impl
->
ctx
.
from_value
(
v
.
at
(
"context"
));
std
::
cout
<<
"WARNING: MXR File was created using MIGraphX version: "
<<
migx_version
<<
", while installed MIGraphX is at version: "
<<
get_migraphx_version
()
<<
", operators implementation could be mismatched."
;
}
migraphx
::
from_value
(
v
.
at
(
"targets"
),
this
->
impl
->
targets
);
for
(
auto
i
:
range
(
this
->
impl
->
targets
.
size
()))
{
this
->
impl
->
contexts
.
push_back
(
this
->
impl
->
targets
[
i
].
get_context
());
this
->
impl
->
contexts
.
back
().
from_value
(
v
.
at
(
"contexts"
)[
i
]);
}
auto
module_vals
=
v
.
at
(
"modules"
);
...
...
@@ -655,6 +798,8 @@ void program::from_value(const value& v)
auto
*
mm
=
get_main_module
();
mod_from_val
(
mm
,
module_vals
,
map_insts
,
map_mods
);
// Finalize a compiled model
if
(
not
this
->
impl
->
contexts
.
empty
())
this
->
finalize
();
}
...
...
@@ -675,19 +820,19 @@ std::string perf_group(const operation& op)
void
program
::
mark
(
const
parameter_map
&
params
,
marker
&&
m
)
{
auto
&
ctx
=
this
->
impl
->
c
tx
;
auto
&
ctx
=
this
->
impl
->
c
ontexts
;
// Run once by itself
eval
(
params
);
ctx
.
finish
();
this
->
finish
();
// Start marking
m
.
mark_start
(
*
this
);
generic_eval
(
*
this
,
ctx
,
params
,
always
(
[
&
](
auto
ins
,
auto
f
)
{
generic_eval
(
*
this
,
ctx
,
params
,
[
&
](
auto
ins
,
auto
f
)
{
argument
result
;
m
.
mark_start
(
ins
);
result
=
f
();
m
.
mark_stop
(
ins
);
return
result
;
})
)
;
});
m
.
mark_stop
(
*
this
);
}
...
...
@@ -696,10 +841,10 @@ void program::perf_report(std::ostream& os,
parameter_map
params
,
std
::
size_t
batch
)
const
{
auto
&
ctx
=
this
->
impl
->
c
tx
;
auto
&
ctx
=
this
->
impl
->
c
ontexts
;
// Run once by itself
eval
(
params
);
ctx
.
finish
();
this
->
finish
();
// Run and time entire program
std
::
vector
<
double
>
total_vec
;
total_vec
.
reserve
(
n
);
...
...
@@ -707,28 +852,28 @@ void program::perf_report(std::ostream& os,
{
total_vec
.
push_back
(
time
<
milliseconds
>
([
&
]
{
eval
(
params
);
ctx
.
finish
();
this
->
finish
();
}));
}
std
::
sort
(
total_vec
.
begin
(),
total_vec
.
end
());
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
double
>>
ins_vec
;
// Fill the map
generic_eval
(
*
this
,
ctx
,
params
,
always
(
[
&
](
auto
ins
,
auto
)
{
generic_eval
(
*
this
,
ctx
,
params
,
[
&
](
auto
ins
,
auto
)
{
ins_vec
[
ins
].
reserve
(
n
);
return
argument
{
ins
->
get_shape
(),
nullptr
};
})
)
;
});
// Run and time each instruction
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
{
generic_eval
(
*
this
,
ctx
,
params
,
always
(
[
&
](
auto
ins
,
auto
f
)
{
generic_eval
(
*
this
,
ctx
,
params
,
[
&
](
auto
ins
,
auto
f
)
{
argument
result
;
ins_vec
[
ins
].
push_back
(
time
<
milliseconds
>
([
&
]
{
result
=
f
();
ctx
.
finish
();
this
->
impl
->
contexts
[
ins
->
get_target_id
()]
.
finish
();
}));
return
result
;
})
)
;
});
}
for
(
auto
&&
p
:
ins_vec
)
std
::
sort
(
p
.
second
.
begin
(),
p
.
second
.
end
());
...
...
@@ -861,7 +1006,9 @@ void program::print_py(std::ostream& os) const
os
<<
"p = migraphx.program()
\n
"
;
for
(
auto
&
mod
:
vec_modules
)
{
std
::
string
var_name
=
"m"
+
mod
->
name
();
std
::
string
var_name
=
"m"
;
if
(
mod
->
name
()
!=
"main"
)
var_name
+=
mod
->
name
();
os
<<
var_name
<<
" = "
;
if
(
mod
->
name
()
==
"main"
)
os
<<
"p.get_main_module()"
;
...
...
@@ -894,10 +1041,10 @@ void program::print_cpp(std::ostream& os) const
void
program
::
dry_run
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
auto
&
ctx
=
this
->
impl
->
c
tx
;
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
always
(
[](
auto
ins
,
auto
&&
...)
{
auto
&
ctx
=
this
->
impl
->
c
ontexts
;
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[](
auto
ins
,
auto
&&
...)
{
return
argument
{
ins
->
get_shape
(),
nullptr
};
})
)
;
});
}
void
program
::
annotate
(
std
::
ostream
&
os
,
const
std
::
function
<
void
(
instruction_ref
)
>&
a
)
const
...
...
@@ -1039,17 +1186,25 @@ void program::remove_unused_modules()
std
::
vector
<
module
*>
unused
;
generic_get_unused_modules
(
impl
->
modules
,
generic_get_modules
(
this
->
get_main_module
()),
std
::
back_inserter
(
unused
));
for
(
auto
*
m
:
unused
)
for
(
const
auto
*
m
:
unused
)
this
->
remove_module
(
m
->
name
());
}
program
&
program
::
sort
()
{
for
(
auto
&
pp
:
this
->
impl
->
modules
)
std
::
queue
<
migraphx
::
module_ref
>
mqueue
;
mqueue
.
push
(
get_main_module
());
while
(
not
mqueue
.
empty
())
{
pp
.
second
.
sort
();
module_ref
current_mod
=
mqueue
.
front
();
current_mod
->
sort
();
mqueue
.
pop
();
auto
child_mods
=
current_mod
->
get_sub_modules
(
true
);
for
(
auto
&
sub_mod
:
child_mods
)
{
mqueue
.
push
(
sub_mod
);
}
}
return
*
this
;
}
...
...
src/promote_literals.cpp
View file @
23cb7917
...
...
@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
{
module
&
m
=
mpm
.
get_module
();
module_ref
root_module
=
mpm
.
get_root_module
();
if
(
m
.
name
()
==
"main"
)
if
(
m
==
*
root_module
)
return
;
for
(
auto
ins
:
iterator_for
(
m
))
...
...
src/py/CMakeLists.txt
View file @
23cb7917
...
...
@@ -23,14 +23,25 @@
#####################################################################################
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
add_library
(
migraphx_py py_loader.cpp
)
migraphx_generate_export_header
(
migraphx_py
)
target_include_directories
(
migraphx_py PRIVATE include
)
target_link_libraries
(
migraphx_py PUBLIC migraphx
)
rocm_install_targets
(
TARGETS migraphx_py INCLUDE include
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
include
(
PythonModules
)
add_custom_target
(
migraphx_py
)
foreach
(
PYTHON_VERSION
${
PYTHON_VERSIONS
}
)
py_add_module
(
migraphx_py_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
py_add_module
(
migraphx_pybind_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
target_link_libraries
(
migraphx_pybind_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
rocm_install_targets
(
TARGETS migraphx_pybind_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_pybind_
${
PYTHON_VERSION
}
)
add_library
(
migraphx_py_
${
PYTHON_VERSION
}
py.cpp
)
target_include_directories
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE include
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PUBLIC migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE pybind11::pybind11 python
${
PYTHON_VERSION
}
::runtime
)
rocm_install_targets
(
TARGETS migraphx_py_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_py_
${
PYTHON_VERSION
}
)
endforeach
()
...
...
src/py/include/migraphx/py.hpp
0 → 100644
View file @
23cb7917
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
src/py/migraphx_py.cpp
View file @
23cb7917
...
...
@@ -95,6 +95,10 @@ void visit_py(T x, F f)
{
f
(
x
.
template
cast
<
std
::
string
>());
}
else
if
(
py
::
isinstance
<
migraphx
::
shape
::
dynamic_dimension
>
(
x
))
{
f
(
migraphx
::
to_value
(
x
.
template
cast
<
migraphx
::
shape
::
dynamic_dimension
>()));
}
else
{
MIGRAPHX_THROW
(
"VISIT_PY: Unsupported data type!"
);
...
...
@@ -165,6 +169,9 @@ template <class T>
py
::
buffer_info
to_buffer_info
(
T
&
x
)
{
migraphx
::
shape
s
=
x
.
get_shape
();
assert
(
s
.
type
()
!=
migraphx
::
shape
::
tuple_type
);
if
(
s
.
dynamic
())
MIGRAPHX_THROW
(
"MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info"
);
auto
strides
=
s
.
strides
();
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
{
return
i
*
s
.
type_size
();
});
...
...
@@ -177,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
py
::
format_descriptor
<
bool
>::
format
(),
s
.
lens
().
size
(),
s
.
ndim
(),
s
.
lens
(),
strides
);
}
...
...
@@ -186,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
s
.
lens
().
size
(),
s
.
ndim
(),
s
.
lens
(),
strides
);
}
...
...
@@ -241,6 +248,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
py
::
init
([](
py
::
kwargs
kwargs
)
{
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
"float"
));
if
(
v
.
contains
(
"dyn_dims"
))
{
auto
dyn_dims
=
migraphx
::
from_value
<
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(
v
.
at
(
"dyn_dims"
));
return
migraphx
::
shape
(
t
,
dyn_dims
);
}
auto
lens
=
v
.
get
<
std
::
size_t
>
(
"lens"
,
{
1
});
if
(
v
.
contains
(
"strides"
))
return
migraphx
::
shape
(
t
,
lens
,
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
());
...
...
@@ -250,15 +264,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"ndim"
,
&
migraphx
::
shape
::
ndim
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"type_string"
,
&
migraphx
::
shape
::
type_string
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"dyn_dims"
,
&
migraphx
::
shape
::
dyn_dims
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
.
def
(
"broadcasted"
,
&
migraphx
::
shape
::
broadcasted
)
.
def
(
"standard"
,
&
migraphx
::
shape
::
standard
)
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
)
.
def
(
"dynamic"
,
&
migraphx
::
shape
::
dynamic
)
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
...
...
@@ -266,6 +283,15 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
enum_
<
migraphx
::
shape
::
type_t
>
(
shape_cls
,
"type_t"
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM
);
py
::
class_
<
migraphx
::
shape
::
dynamic_dimension
>
(
shape_cls
,
"dynamic_dimension"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
std
::
size_t
,
std
::
size_t
>
())
.
def
(
py
::
init
<
std
::
size_t
,
std
::
size_t
,
std
::
set
<
std
::
size_t
>>
())
.
def_readwrite
(
"min"
,
&
migraphx
::
shape
::
dynamic_dimension
::
min
)
.
def_readwrite
(
"max"
,
&
migraphx
::
shape
::
dynamic_dimension
::
max
)
.
def_readwrite
(
"optimals"
,
&
migraphx
::
shape
::
dynamic_dimension
::
optimals
)
.
def
(
"is_fixed"
,
&
migraphx
::
shape
::
dynamic_dimension
::
is_fixed
);
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
py
::
init
([](
py
::
buffer
b
)
{
...
...
@@ -440,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx"
,
[](
const
std
::
string
&
filename
,
unsigned
int
default_dim_value
,
migraphx
::
shape
::
dynamic_dimension
default_dyn_dim_value
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
print_program_on_error
,
int64_t
max_loop_iterations
)
{
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_dyn_input_dims
=
map_dyn_input_dims
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
max_loop_iterations
=
max_loop_iterations
;
...
...
@@ -454,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Parse onnx file"
,
py
::
arg
(
"filename"
),
py
::
arg
(
"default_dim_value"
)
=
1
,
py
::
arg
(
"default_dim_value"
)
=
0
,
py
::
arg
(
"default_dyn_dim_value"
)
=
migraphx
::
shape
::
dynamic_dimension
{
1
,
1
},
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_dyn_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"max_loop_iterations"
)
=
10
);
...
...
@@ -464,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx_buffer"
,
[](
const
std
::
string
&
onnx_buffer
,
unsigned
int
default_dim_value
,
migraphx
::
shape
::
dynamic_dimension
default_dyn_dim_value
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
print_program_on_error
)
{
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_dyn_input_dims
=
map_dyn_input_dims
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
return
migraphx
::
parse_onnx_buffer
(
onnx_buffer
,
options
);
},
"Parse onnx file"
,
py
::
arg
(
"filename"
),
py
::
arg
(
"default_dim_value"
)
=
1
,
py
::
arg
(
"default_dim_value"
)
=
0
,
py
::
arg
(
"default_dyn_dim_value"
)
=
migraphx
::
shape
::
dynamic_dimension
{
1
,
1
},
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_dyn_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
);
...
...
@@ -505,6 +547,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
arg
(
"format"
)
=
"msgpack"
);
m
.
def
(
"get_target"
,
&
migraphx
::
make_target
);
m
.
def
(
"create_argument"
,
[](
const
migraphx
::
shape
&
s
,
const
std
::
vector
<
double
>&
values
)
{
if
(
values
.
size
()
!=
s
.
elements
())
MIGRAPHX_THROW
(
"Values and shape elements do not match"
);
migraphx
::
argument
a
{
s
};
a
.
fill
(
values
.
begin
(),
values
.
end
());
return
a
;
});
m
.
def
(
"generate_argument"
,
&
migraphx
::
generate_argument
,
py
::
arg
(
"s"
),
py
::
arg
(
"seed"
)
=
0
);
m
.
def
(
"fill_argument"
,
&
migraphx
::
fill_argument
,
py
::
arg
(
"s"
),
py
::
arg
(
"value"
));
m
.
def
(
"quantize_fp16"
,
...
...
src/py/py.cpp
0 → 100644
View file @
23cb7917
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/file_buffer.hpp>
#include <pybind11/embed.h>
namespace
py
=
pybind11
;
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
// extern "C" is used to disable name mangling, but the function will still be called from C++
extern
"C"
program
migraphx_load_py
(
const
std
::
string
&
filename
);
#ifdef __clang__
#pragma clang diagnostic pop
#endif
const
std
::
string
&
python_path
()
{
static
const
auto
path
=
dynamic_loader
::
path
(
&
migraphx_load_py
).
parent_path
().
string
();
return
path
;
}
static
py
::
dict
run_file
(
const
std
::
string
&
file
)
{
py
::
object
scope
=
py
::
module_
::
import
(
"__main__"
).
attr
(
"__dict__"
);
std
::
string
buffer
;
buffer
.
append
(
"import sys
\n
"
);
buffer
.
append
(
"sys.path.insert(0, '"
+
python_path
()
+
"')
\n
"
);
buffer
.
append
(
"import migraphx
\n
"
);
buffer
.
append
(
read_string
(
file
));
py
::
exec
(
buffer
,
scope
);
return
scope
.
cast
<
py
::
dict
>
();
}
extern
"C"
program
migraphx_load_py
(
const
std
::
string
&
filename
)
{
py
::
scoped_interpreter
guard
{};
py
::
dict
vars
=
run_file
(
filename
);
auto
it
=
std
::
find_if
(
vars
.
begin
(),
vars
.
end
(),
[](
const
auto
&
p
)
{
return
py
::
isinstance
<
migraphx
::
program
>
(
p
.
second
);
});
if
(
it
==
vars
.
end
())
MIGRAPHX_THROW
(
"No program variable found"
);
return
it
->
second
.
cast
<
migraphx
::
program
>
();
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/py/py_loader.cpp
0 → 100644
View file @
23cb7917
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/py.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
static
std
::
vector
<
fs
::
path
>
find_available_python_versions
()
{
std
::
vector
<
fs
::
path
>
result
;
auto
path
=
dynamic_loader
::
path
(
&
load_py
).
parent_path
();
for
(
const
auto
&
entry
:
fs
::
directory_iterator
{
path
})
{
auto
p
=
entry
.
path
();
if
(
not
fs
::
is_regular_file
(
p
))
continue
;
if
(
not
contains
(
p
.
stem
().
string
(),
"migraphx_py_"
))
continue
;
result
.
push_back
(
p
);
}
std
::
sort
(
result
.
begin
(),
result
.
end
(),
std
::
greater
<>
{});
return
result
;
}
static
dynamic_loader
load_py_lib
()
{
auto
libs
=
find_available_python_versions
();
for
(
const
auto
&
lib
:
libs
)
{
auto
result
=
dynamic_loader
::
try_load
(
lib
);
if
(
result
.
has_value
())
return
*
result
;
}
MIGRAPHX_THROW
(
"Cant find a viable version of python"
);
}
static
dynamic_loader
py_lib
()
{
static
dynamic_loader
lib
=
load_py_lib
();
return
lib
;
}
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
)
{
static
auto
f
=
py_lib
().
get_function
<
program
(
const
std
::
string
&
)
>
(
"migraphx_load_py"
);
return
f
(
filename
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/quantization.cpp
View file @
23cb7917
...
...
@@ -29,6 +29,7 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
...
...
@@ -48,19 +49,12 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it
// is
very rare in the area of deeping learning, so we just do a
//
truncate of the input to get the fp16
.
// For the conversion, there could be cases of overflowing
or underflowing
, but it
// is
uncommon. Run optimize_module() before converting to fp16 to const eval and fold in FP32 to
//
avoid loss of precision
.
void
quantize_fp16
(
program
&
prog
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
{
run_passes
(
prog
,
{
quantize_fp16_pass
{
ins_names
},
eliminate_common_subexpression
{},
dead_code_elimination
{},
simplify_reshapes
{},
dead_code_elimination
{},
simplify_qdq
{},
dead_code_elimination
{}});
run_passes
(
prog
,
{
optimize_module
{},
quantize_fp16_pass
{
ins_names
},
optimize_module
{}});
}
void
quantize_int8
(
program
&
prog
,
...
...
src/quantize_fp16.cpp
View file @
23cb7917
...
...
@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
auto
mod_inputs
=
ins
->
module_inputs
();
auto
s
=
ins
->
get_shape
();
// Convert back to original type before quantizing the inputs
if
(
mod_inputs
.
empty
())
{
auto
r
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"convert"
,
{{
"target_type"
,
s
.
type
()}}),
ins
);
m
.
replace_instruction
(
ins
,
r
);
}
// Convert each of the inputs that are floating point to fp16
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
...
...
@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
half_type
}}),
input
);
});
// Replace inputs
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
mod_inputs
);
// Insert quantized ins
auto
converted_ins
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
mod_inputs
);
// Convert back to original type after quantizing
if
(
mod_inputs
.
empty
())
{
converted_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
s
.
type
()}}),
converted_ins
);
}
// Replace original instruction
m
.
replace_instruction
(
ins
,
converted_ins
);
}
}
...
...
src/replace_allocate.cpp
View file @
23cb7917
...
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/pass_manager.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
...
...
@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
mod
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
mod_args
);
}
void
replace_allocate
::
apply
(
module
&
m
)
const
void
replace_allocate
::
apply
(
module
_pass_manager
&
mp
m
)
const
{
module
&
m
=
mpm
.
get_module
();
auto
mod_output_names
=
create_output_names
(
m
);
bool
main
_offload_copy
=
m
.
nam
e
()
==
"main"
?
this
->
offload_copy
:
false
;
bool
root
_offload_copy
=
(
*
mpm
.
get_root_modul
e
()
==
m
)
?
this
->
offload_copy
:
false
;
for
(
auto
ins
:
iterator_for
(
m
))
{
auto
op
=
ins
->
get_operator
();
...
...
@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
continue
;
auto
s
=
ins
->
get_shape
();
if
(
not
main
_offload_copy
and
model
.
needs_out_params
()
and
contains
(
mod_output_names
,
ins
))
if
(
not
root
_offload_copy
and
model
.
needs_out_params
()
and
contains
(
mod_output_names
,
ins
))
{
auto
out_param
=
m
.
add_parameter
(
mod_output_names
[
ins
],
s
);
m
.
replace_instruction
(
ins
,
out_param
);
...
...
src/rewrite_quantization.cpp
View file @
23cb7917
...
...
@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -40,15 +41,18 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if
(
x
->
get_shape
().
type
()
!=
y_scale
->
get_shape
().
type
())
{
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
x
);
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
y_scale
->
get_shape
().
type
()}}),
x
);
}
auto
div
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
y_scale
);
auto
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"round"
),
div
);
if
(
ins
->
inputs
().
size
()
==
3
)
{
auto
zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
ins
->
inputs
()[
2
]);
auto
zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
y_scale
->
get_shape
().
type
()}}),
ins
->
inputs
()[
2
]);
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
}
...
...
@@ -59,12 +63,9 @@ void apply_quantizelinear(module& m, instruction_ref ins)
min_quant
=
qt
.
min
();
});
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
...
...
@@ -72,14 +73,16 @@ void apply_quantizelinear(module& m, instruction_ref ins)
void
apply_dequantizelinear
(
module
&
m
,
instruction_ref
ins
)
{
assert
(
ins
->
name
()
==
"dequantizelinear"
);
auto
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
ins
->
inputs
()[
0
]);
auto
x_scale
=
ins
->
inputs
()[
1
];
auto
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
x_scale
->
get_shape
().
type
()}}),
ins
->
inputs
()[
0
]);
if
(
ins
->
inputs
().
size
()
==
3
)
{
auto
x_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
ins
->
inputs
()[
2
]);
auto
x_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
x_scale
->
get_shape
().
type
()}}),
ins
->
inputs
()[
2
]);
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"sub"
),
x
,
x_zero_point
);
}
...
...
src/shape.cpp
View file @
23cb7917
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -273,9 +273,23 @@ shape shape::from_permutation(type_t t,
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: lens() called on a dynamic shape"
);
}
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: strides() called on a dynamic shape"
);
}
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
ndim
()
const
{
...
...
@@ -535,7 +549,14 @@ bool shape::any_of_dynamic() const
});
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
return
impl
->
m_dyn_dims
;
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
if
(
not
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: dyn_dims() called on a static shape"
);
}
return
impl
->
m_dyn_dims
;
}
std
::
vector
<
std
::
size_t
>
shape
::
min_lens
()
const
{
...
...
@@ -680,10 +701,20 @@ void migraphx_to_value(value& v, const shape& s)
{
value
result
;
result
[
"type"
]
=
migraphx
::
to_value
(
s
.
type_string
());
result
[
"lens"
]
=
migraphx
::
to_value
(
s
.
lens
());
result
[
"strides"
]
=
migraphx
::
to_value
(
s
.
strides
());
result
[
"sub_shapes"
]
=
migraphx
::
to_value
(
s
.
sub_shapes
());
// avoid calling functions that will throw
if
(
s
.
dynamic
())
{
result
[
"lens"
]
=
{};
result
[
"strides"
]
=
{};
result
[
"dynamic_dimensions"
]
=
migraphx
::
to_value
(
s
.
dyn_dims
());
}
else
{
result
[
"lens"
]
=
migraphx
::
to_value
(
s
.
lens
());
result
[
"strides"
]
=
migraphx
::
to_value
(
s
.
strides
());
result
[
"dynamic_dimensions"
]
=
{};
}
v
=
result
;
}
...
...
@@ -706,13 +737,9 @@ void migraphx_from_value(const value& v, shape& s)
{
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
migraphx
::
value
x
)
{
auto
x_min
=
x
.
at
(
"min"
).
template
to
<
size_t
>();
auto
x_max
=
x
.
at
(
"max"
).
template
to
<
size_t
>();
auto
v_optimals
=
x
.
at
(
"optimals"
);
std
::
set
<
size_t
>
set_x_optimals
=
from_value
<
std
::
set
<
std
::
size_t
>>
(
x
.
at
(
"optimals"
));
return
shape
::
dynamic_dimension
{
x_min
,
x_max
,
set_x_optimals
};
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
const
migraphx
::
value
&
x
)
{
return
from_value
<
shape
::
dynamic_dimension
>
(
x
);
});
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
...
...
src/simplify_algebra.cpp
View file @
23cb7917
...
...
@@ -204,6 +204,131 @@ struct find_mul_slice_conv
}
};
struct
find_mul_dot
{
auto
matcher
()
const
{
auto
is_dot_const_inputs
=
match
::
name
(
"dot"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
is_constant
()));
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
is_dot_const_inputs
.
bind
(
"dot"
),
match
::
name
(
"broadcast"
,
"multibroadcast"
).
bind
(
"c"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
dot_ins
=
r
.
instructions
[
"dot"
];
auto
a_ins
=
dot_ins
->
inputs
()[
0
];
auto
b_ins
=
dot_ins
->
inputs
()[
1
];
auto
c_ins
=
r
.
instructions
[
"c"
];
const
auto
&
c_strides
=
c_ins
->
get_shape
().
strides
();
// There should only be one stride that is not zero
if
(
std
::
count_if
(
c_strides
.
begin
(),
c_strides
.
end
(),
[](
auto
s
)
{
return
s
!=
0
;
})
>
1
)
return
;
auto
add_mul_const
=
[
&
](
instruction_ref
x_ins
)
{
if
(
not
x_ins
->
can_eval
())
return
m
.
end
();
auto
broadcast_v
=
c_ins
->
get_operator
().
to_value
();
broadcast_v
[
"out_lens"
]
=
x_ins
->
get_shape
().
lens
();
auto
cb_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
c_ins
->
name
(),
broadcast_v
),
c_ins
->
inputs
());
return
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
x_ins
,
cb_ins
);
};
if
(
c_strides
.
back
()
==
1
)
{
b_ins
=
add_mul_const
(
b_ins
);
}
else
if
(
c_strides
[
c_strides
.
size
()
-
2
]
==
1
)
{
a_ins
=
add_mul_const
(
a_ins
);
}
else
if
(
c_ins
->
get_shape
().
scalar
())
{
if
(
a_ins
->
can_eval
())
a_ins
=
add_mul_const
(
a_ins
);
else
b_ins
=
add_mul_const
(
b_ins
);
}
else
{
return
;
}
if
(
contains
({
a_ins
,
b_ins
},
m
.
end
()))
return
;
m
.
replace_instruction
(
ins
,
make_op
(
"dot"
),
a_ins
,
b_ins
);
}
};
struct
find_dot_mul
{
auto
matcher
()
const
{
auto
const_broadcast
=
match
::
name
(
"broadcast"
,
"multibroadcast"
)(
match
::
is_constant
());
auto
mul
=
match
::
name
(
"mul"
)(
match
::
used_once
(),
match
::
either_arg
(
0
,
1
)(
const_broadcast
.
bind
(
"d"
),
match
::
none_of
(
match
::
is_constant
()).
bind
(
"z"
)));
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
mul
,
match
::
is_constant
().
bind
(
"c"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
ins
->
inputs
()[
0
];
auto
b_ins
=
ins
->
inputs
()[
1
];
auto
d_ins
=
r
.
instructions
[
"d"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
z_ins
=
r
.
instructions
[
"z"
];
const
auto
&
d_strides
=
d_ins
->
get_shape
().
strides
();
// There should only be one stride that is not zero
if
(
std
::
count_if
(
d_strides
.
begin
(),
d_strides
.
end
(),
[](
auto
s
)
{
return
s
!=
0
;
})
>
1
)
return
;
if
(
not
d_ins
->
get_shape
().
scalar
())
{
if
(
d_strides
.
back
()
==
1
and
not
b_ins
->
can_eval
())
return
;
if
(
d_strides
[
d_strides
.
size
()
-
2
]
==
1
and
not
a_ins
->
can_eval
())
return
;
}
auto
broadcast_v
=
d_ins
->
get_operator
().
to_value
();
auto
c_lens
=
c_ins
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
permutation
(
c_lens
.
size
());
std
::
iota
(
permutation
.
begin
(),
permutation
.
end
(),
0
);
std
::
swap
(
permutation
.
back
(),
permutation
[
permutation
.
size
()
-
2
]);
c_lens
=
reorder_dims
(
c_lens
,
permutation
);
broadcast_v
[
"out_lens"
]
=
c_lens
;
auto
db_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
d_ins
->
name
(),
broadcast_v
),
d_ins
->
inputs
());
auto
db_transpose_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
permutation
}}),
db_ins
);
auto
cd_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
c_ins
,
db_transpose_ins
);
if
(
c_ins
==
b_ins
)
{
a_ins
=
z_ins
;
b_ins
=
cd_ins
;
}
else
{
a_ins
=
cd_ins
;
b_ins
=
z_ins
;
}
m
.
replace_instruction
(
ins
,
make_op
(
"dot"
),
a_ins
,
b_ins
);
}
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
...
...
@@ -361,30 +486,123 @@ struct find_inner_broadcast
{
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
static
auto
non_scalar_op
(
const
std
::
string
&
name
)
{
return
[
=
](
instruction_ref
ins
)
{
if
(
ins
->
get_shape
().
scalar
())
return
false
;
return
ins
->
name
()
==
name
;
};
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
broadcasts
=
ins
->
inputs
();
if
(
broadcasts
.
empty
())
return
;
// Skip if different data types are used
if
(
any_of
(
broadcasts
,
[
&
](
auto
i
)
{
return
i
->
get_shape
().
type
()
!=
broadcasts
.
front
()
->
get_shape
().
type
();
}))
return
;
bool
mixed_broadcasts
=
any_of
(
broadcasts
,
non_scalar_op
(
"broadcast"
))
and
any_of
(
broadcasts
,
non_scalar_op
(
"multibroadcast"
));
// If the broadcast is not a single dimension, then dont perform inner_broadcast
if
(
mixed_broadcasts
and
any_of
(
broadcasts
,
[
&
](
instruction_ref
i
)
{
if
(
i
->
get_shape
().
scalar
())
return
false
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
false
;
auto
input
=
i
->
inputs
().
at
(
0
);
const
auto
&
lens
=
input
->
get_shape
().
lens
();
return
std
::
count_if
(
lens
.
begin
(),
lens
.
end
(),
[
&
](
std
::
size_t
d
)
{
return
d
==
1
;
})
<
(
lens
.
size
()
-
1
);
}))
return
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
std
::
back_inserter
(
inputs
),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
()
and
i
->
get_shape
().
elements
()
!=
1
;
}))
return
;
auto
b_it
=
std
::
find_if
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
[
&
](
auto
i
)
{
return
not
i
->
get_shape
().
scalar
();
[
&
](
instruction_ref
i
)
{
auto
input
=
i
->
inputs
().
front
();
if
(
mixed_broadcasts
and
not
i
->
get_shape
().
scalar
()
and
i
->
get_shape
().
lens
().
size
()
>
1
)
return
m
.
insert_instruction
(
i
,
make_op
(
"squeeze"
),
input
);
return
input
;
});
if
(
b_it
==
broadcasts
.
end
())
b_it
=
broadcasts
.
begin
();
std
::
sort
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
by
(
std
::
less
<>
{},
[](
instruction_ref
i
)
{
if
(
i
->
get_shape
().
scalar
())
return
2
;
else
if
(
i
->
name
()
==
"broadcast"
)
return
0
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
1
;
return
3
;
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
m
.
replace_instruction
(
ins
,
(
*
b_it
)
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
};
struct
find_dot_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a
=
ins
->
inputs
()[
0
];
auto
b
=
ins
->
inputs
()[
1
];
if
(
a
->
get_operator
().
name
()
!=
b
->
get_operator
().
name
())
return
;
if
(
ins
->
get_shape
().
lens
().
size
()
<
3
)
return
;
auto
nbatch_axes
=
ins
->
get_shape
().
lens
().
size
()
-
2
;
const
auto
&
a_strides
=
a
->
get_shape
().
strides
();
const
auto
&
b_strides
=
b
->
get_shape
().
strides
();
// Find leading batch axes that are broadcasted
auto
p
=
std
::
mismatch
(
a_strides
.
begin
(),
a_strides
.
begin
()
+
nbatch_axes
,
b_strides
.
begin
(),
b_strides
.
begin
()
+
nbatch_axes
,
[](
auto
astride
,
auto
bstride
)
{
return
astride
==
0
and
bstride
==
0
;
});
auto
naxes
=
p
.
first
-
a_strides
.
begin
();
assert
(
naxes
<=
nbatch_axes
);
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
auto
insert_broadcast
=
[
&
](
instruction_ref
b_ins
)
->
instruction_ref
{
auto
input
=
b_ins
->
inputs
()[
0
];
std
::
vector
<
std
::
size_t
>
lens
(
b_ins
->
get_shape
().
lens
().
begin
()
+
naxes
,
b_ins
->
get_shape
().
lens
().
end
());
if
(
b_ins
->
name
()
==
"multibroadcast"
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
input
);
}
else
if
(
b_ins
->
name
()
==
"broadcast"
)
{
auto
v
=
b_ins
->
get_operator
().
to_value
();
auto
axis
=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
()
-
naxes
;
return
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
lens
}}),
input
);
}
assert
(
false
);
return
m
.
end
();
};
auto
a1
=
insert_broadcast
(
a
);
auto
b1
=
insert_broadcast
(
b
);
auto
dot
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
a1
,
b1
);
auto
broadcast
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins
->
get_shape
().
lens
()}}),
dot
);
m
.
replace_instruction
(
ins
,
broadcast
);
}
};
...
...
@@ -393,7 +611,8 @@ struct find_concat_op
auto
matcher
()
const
{
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
)),
match
::
used_once
()));
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)),
match
::
used_once
()));
}
template
<
class
Iterator
>
...
...
@@ -412,7 +631,8 @@ struct find_concat_op
static
bool
is_valid_op
(
const
operation
&
op
)
{
return
op
.
name
()
==
"broadcast"
or
op
.
attributes
().
contains
(
"pointwise"
);
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
())
or
op
.
attributes
().
contains
(
"pointwise"
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -440,6 +660,16 @@ struct find_concat_op
op
=
b
;
iaxis
=
0
;
}
else
if
(
op
.
name
()
==
"multibroadcast"
)
{
shape
bshape
=
(
*
start
)
->
get_shape
();
auto
input
=
(
*
start
)
->
inputs
()[
0
];
if
(
iaxis
>=
bshape
.
strides
().
size
()
or
bshape
.
strides
()[
iaxis
]
==
0
)
return
{
start
,
last
};
op
.
from_value
({{
"out_lens"
,
get_output_lens
(
start
,
last
,
iaxis
)}});
auto
delta
=
bshape
.
lens
().
size
()
-
input
->
get_shape
().
lens
().
size
();
iaxis
-=
delta
;
}
std
::
vector
<
instruction_ref
>
concats
;
for
(
std
::
size_t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
...
...
@@ -865,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
struct
find_conv_dot_horiz_fusion
...
...
@@ -880,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
...
@@ -888,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
// Check that non-axes match
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
axis
=
x
.
size
()
-
1
;
}
...
...
@@ -899,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
...
...
@@ -914,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
...
...
@@ -1260,12 +1491,15 @@ void simplify_algebra::apply(module& m) const
{
match
::
find_matches
(
m
,
find_inner_broadcast
{},
find_dot_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_convs
{},
find_conv_dot_horiz_fusion
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_dot
{},
find_dot_mul
{},
find_mul_add
{},
find_unit_ops
{},
find_neg_unit_ops
{},
...
...
src/simplify_reshapes.cpp
View file @
23cb7917
...
...
@@ -89,38 +89,23 @@ struct find_reshaper
{
auto
matcher
()
const
{
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
auto
reshaper
=
match
::
name
(
reshaper_names
());
auto
contiguous
=
match
::
name
(
"contiguous"
);
auto
no_output_reshape
=
match
::
none_of
[
match
::
outputs
()](
reshaper
);
auto
input_reshape
=
match
::
arg
(
0
)(
match
::
skip
(
contiguous
)(
reshaper
));
auto
input
=
match
::
skip
(
reshaper
,
contiguous
)(
match
::
any
().
bind
(
"x"
));
return
reshaper
(
no_output_reshape
,
input_reshape
,
input
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
assert
(
not
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
auto
input
=
mr
.
instructions
[
"x"
];
auto
dims
=
ins
->
get_shape
().
lens
();
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
m
.
end
(),
m
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
if
(
r
.
first
!=
r
.
second
)
{
m
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
if
(
not
input
->
get_shape
().
standard
())
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
input
);
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
input
);
}
};
...
...
@@ -804,9 +789,9 @@ void simplify_reshapes::apply(module& m) const
match
::
find_matches
(
m
,
find_where_op
{},
find_resize
{},
find_reshape_cont
{},
find_nop_reshapes
{},
find_reshaper
{},
find_reshape_cont
{},
find_transpose
{},
find_concat_transpose
{},
find_concat_multibroadcasts
{},
...
...
src/split_single_dyn_dim.cpp
View file @
23cb7917
...
...
@@ -100,10 +100,10 @@ struct find_static_2in_broadcasts
}
// namespace
/**
* Makes all the shapes in the dynamic_dimension range.
*
Probably won't work for `if`
and `loop` instructions, depending on how the submodules for those
* Makes all the shapes in the dynamic_dimension range.
Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions.
* instructions.
Skips if the dynamic parameter outputs to a select_module operator.
*/
void
split_single_dyn_dim
::
apply
(
module_pass_manager
&
mpm
)
const
{
...
...
@@ -111,7 +111,13 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
auto
param_names
=
mm
->
get_parameter_names
();
auto
param_shapes
=
mm
->
get_parameter_shapes
();
optional
<
dynamic_dimensions_check
>
dd_check
=
has_one_dyn_dim
(
param_shapes
);
if
(
dd_check
.
has_value
())
auto
any_sm_next
=
[
&
](
auto
ddc
)
{
auto
p_outputs
=
mm
->
get_parameter
(
ddc
->
dyn_param_str
)
->
outputs
();
return
std
::
any_of
(
p_outputs
.
cbegin
(),
p_outputs
.
cend
(),
[](
auto
ins
)
{
return
ins
->
name
()
==
"select_module"
;
});
};
if
(
dd_check
.
has_value
()
and
not
any_sm_next
(
dd_check
))
{
const
auto
&
dyn_param
=
mm
->
get_parameter
(
dd_check
->
dyn_param_str
);
auto
dyn_param_shape
=
mm
->
get_parameter_shape
(
dd_check
->
dyn_param_str
);
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment