Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
421ecad6
Commit
421ecad6
authored
May 08, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into ck-gsg
parents
3d0426e9
7cf05301
Changes
84
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
494 additions
and
131 deletions
+494
-131
src/api/migraphx.py
src/api/migraphx.py
+56
-44
src/cpp_generator.cpp
src/cpp_generator.cpp
+9
-0
src/driver/main.cpp
src/driver/main.cpp
+20
-5
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+11
-1
src/include/migraphx/compile_options.hpp
src/include/migraphx/compile_options.hpp
+0
-3
src/include/migraphx/cpp_generator.hpp
src/include/migraphx/cpp_generator.hpp
+1
-0
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+9
-0
src/include/migraphx/op/pointwise.hpp
src/include/migraphx/op/pointwise.hpp
+5
-4
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+9
-0
src/include/migraphx/op/select_module.hpp
src/include/migraphx/op/select_module.hpp
+1
-1
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+8
-4
src/include/migraphx/tf.hpp
src/include/migraphx/tf.hpp
+2
-0
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+11
-8
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+15
-12
src/onnx/op_parser.cpp
src/onnx/op_parser.cpp
+1
-0
src/pass_manager.cpp
src/pass_manager.cpp
+1
-0
src/propagate_constant.cpp
src/propagate_constant.cpp
+16
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+60
-11
src/shape.cpp
src/shape.cpp
+16
-23
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+243
-15
No files found.
src/api/migraphx.py
View file @
421ecad6
...
@@ -45,57 +45,48 @@ def shape_type_wrap(p):
...
@@ -45,57 +45,48 @@ def shape_type_wrap(p):
p
.
read
=
'migraphx::to_shape_type(${name})'
p
.
read
=
'migraphx::to_shape_type(${name})'
@
api
.
cwrap
(
'migraphx::compile_options'
)
def
auto_handle
(
*
args
,
**
kwargs
):
def
compile_options_type_wrap
(
p
):
def
with_handle
(
f
):
if
p
.
returns
:
return
api
.
handle
(
'migraphx_'
+
f
.
__name__
,
'migraphx::'
+
f
.
__name__
,
p
.
add_param
(
'migraphx_compile_options *'
)
*
args
,
**
kwargs
)(
f
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
write
=
[
'*${name} = migraphx::to_compile_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_compile_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@
api
.
cwrap
(
'migraphx::file_options'
)
def
file_options_type_wrap
(
p
):
if
p
.
returns
:
p
.
add_param
(
'migraphx_file_options *'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
write
=
[
'*${name} = migraphx::to_file_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_file_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
@
api
.
cwrap
(
'migraphx::onnx_options'
)
return
with_handle
def
onnx_options_type_wrap
(
p
):
if
p
.
returns
:
p
.
add_param
(
'migraphx_onnx_options *'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
write
=
[
'*${name} = migraphx::to_onnx_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_onnx_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@
api
.
cwrap
(
'migraphx::tf_options'
)
@
api
.
handle
(
'migraphx_optimals'
,
'std::set<size_t>'
)
def
tf_options_type_wrap
(
p
):
def
optimals
(
h
):
if
p
.
returns
:
h
.
constructor
(
'create'
,
p
.
add_param
(
'migraphx_tf_options *'
)
api
.
params
(
ptr
=
'const size_t*'
,
size
=
'size_t'
),
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
fname
=
'migraphx::make_set<size_t>'
)
p
.
write
=
[
'*${name} = migraphx::to_tf_options(${result})'
]
else
:
p
.
add_param
(
'migraphx_tf_options *'
)
p
.
read
=
'${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
def
auto_handle
(
*
args
,
**
kwargs
):
@
api
.
handle
(
'migraphx_dynamic_dimension'
,
'migraphx::shape::dynamic_dimension'
)
def
dynamic_dimension
(
h
):
h
.
constructor
(
'create_min_max'
,
api
.
params
(
min
=
'size_t'
,
max
=
'size_t'
))
h
.
constructor
(
'create_min_max_optimals'
,
api
.
params
(
min
=
'size_t'
,
max
=
'size_t'
,
optimals
=
'std::set<size_t>'
))
h
.
method
(
'is_fixed'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::shape::dynamic_dimension&'
),
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
const
=
True
)
def
with_handle
(
f
):
return
api
.
handle
(
'migraphx_'
+
f
.
__name__
,
'migraphx::'
+
f
.
__name__
,
*
args
,
**
kwargs
)(
f
)
return
with_handle
@
api
.
handle
(
'migraphx_dynamic_dimensions'
,
'std::vector<migraphx::shape::dynamic_dimension>'
)
def
dynamic_dimensions
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>'
)
h
.
method
(
'size'
,
returns
=
'size_t'
)
h
.
method
(
'get'
,
api
.
params
(
idx
=
'size_t'
),
fname
=
'at'
,
cpp_name
=
'operator[]'
,
returns
=
'const migraphx::shape::dynamic_dimension&'
)
@
auto_handle
()
@
auto_handle
()
...
@@ -110,20 +101,29 @@ def shape(h):
...
@@ -110,20 +101,29 @@ def shape(h):
lengths
=
'std::vector<size_t>'
,
lengths
=
'std::vector<size_t>'
,
strides
=
'std::vector<size_t>'
))
strides
=
'std::vector<size_t>'
))
h
.
constructor
(
'create_scalar'
,
api
.
params
(
type
=
'migraphx::shape::type_t'
))
h
.
constructor
(
'create_scalar'
,
api
.
params
(
type
=
'migraphx::shape::type_t'
))
h
.
constructor
(
'create_dynamic'
,
api
.
params
(
type
=
'migraphx::shape::type_t'
,
dims
=
'std::vector<migraphx::shape::dynamic_dimension>'
))
h
.
method
(
'lengths'
,
h
.
method
(
'lengths'
,
fname
=
'lens'
,
fname
=
'lens'
,
returns
=
'const std::vector<size_t>&'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'dyn_dims'
,
returns
=
'std::vector<migraphx::shape::dynamic_dimension>'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'elements'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'elements'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'ndim'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'equal'
,
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::shape&'
),
api
.
params
(
x
=
'const migraphx::shape&'
),
invoke
=
'migraphx::equal($@)'
,
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
returns
=
'bool'
,
const
=
True
)
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'dynamic'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'index'
,
api
.
params
(
i
=
'size_t'
),
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'index'
,
api
.
params
(
i
=
'size_t'
),
returns
=
'size_t'
,
const
=
True
)
...
@@ -131,6 +131,7 @@ def shape(h):
...
@@ -131,6 +131,7 @@ def shape(h):
def
argument
(
h
):
def
argument
(
h
):
h
.
constructor
(
'create'
,
h
.
constructor
(
'create'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'void*'
))
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'void*'
))
h
.
constructor
(
'create_empty'
,
api
.
params
(
shape
=
'const migraphx::shape&'
))
h
.
method
(
'shape'
,
h
.
method
(
'shape'
,
fname
=
'get_shape'
,
fname
=
'get_shape'
,
cpp_name
=
'get_shape'
,
cpp_name
=
'get_shape'
,
...
@@ -326,11 +327,22 @@ def onnx_options(h):
...
@@ -326,11 +327,22 @@ def onnx_options(h):
api
.
params
(
name
=
'const char*'
,
dims
=
'std::vector<size_t>'
),
api
.
params
(
name
=
'const char*'
,
dims
=
'std::vector<size_t>'
),
invoke
=
'migraphx::set_input_parameter_shape($@)'
,
invoke
=
'migraphx::set_input_parameter_shape($@)'
,
)
)
h
.
method
(
'set_dyn_input_parameter_shape'
,
api
.
params
(
name
=
'const char*'
,
dims
=
'std::vector<migraphx::shape::dynamic_dimension>'
),
invoke
=
'migraphx::set_dyn_input_parameter_shape($@)'
,
)
h
.
method
(
h
.
method
(
'set_default_dim_value'
,
'set_default_dim_value'
,
api
.
params
(
value
=
'size_t'
),
api
.
params
(
value
=
'size_t'
),
invoke
=
'migraphx::set_default_dim_value($@)'
,
invoke
=
'migraphx::set_default_dim_value($@)'
,
)
)
h
.
method
(
'set_default_dyn_dim_value'
,
api
.
params
(
dd
=
'const migraphx::shape::dynamic_dimension&'
),
invoke
=
'migraphx::set_default_dyn_dim_value($@)'
,
)
h
.
method
(
h
.
method
(
'set_default_loop_iterations'
,
'set_default_loop_iterations'
,
api
.
params
(
value
=
'int64_t'
),
api
.
params
(
value
=
'int64_t'
),
...
...
src/cpp_generator.cpp
View file @
421ecad6
...
@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
...
@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return
*
this
;
return
*
this
;
}
}
cpp_generator
::
function
&
cpp_generator
::
function
::
unused_param
(
const
std
::
string
&
pname
)
{
body
.
insert
(
0
,
"(void)"
+
pname
+
";
\n
"
);
return
*
this
;
}
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
pname
)
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
pname
)
{
{
params
.
push_back
({
pname
,
"T"
+
pname
});
params
.
push_back
({
pname
,
"T"
+
pname
});
...
@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
...
@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
else
if
(
with_char
(
::
isdigit
)(
key
[
0
]))
else
if
(
with_char
(
::
isdigit
)(
key
[
0
]))
{
{
auto
i
=
std
::
stoul
(
key
);
auto
i
=
std
::
stoul
(
key
);
if
(
i
>=
args
.
size
())
MIGRAPHX_THROW
(
"Invalid argument index: "
+
key
);
return
args
.
at
(
i
);
return
args
.
at
(
i
);
}
}
else
if
(
v
.
contains
(
key
))
else
if
(
v
.
contains
(
key
))
...
@@ -238,6 +245,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
...
@@ -238,6 +245,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
std
::
string
name
=
f
.
name
.
empty
()
?
"f"
+
std
::
to_string
(
impl
->
function_count
)
:
f
.
name
;
std
::
string
name
=
f
.
name
.
empty
()
?
"f"
+
std
::
to_string
(
impl
->
function_count
)
:
f
.
name
;
impl
->
fs
<<
join_strings
(
f
.
attributes
,
" "
)
<<
" "
<<
f
.
return_type
<<
" "
<<
name
;
impl
->
fs
<<
join_strings
(
f
.
attributes
,
" "
)
<<
" "
<<
f
.
return_type
<<
" "
<<
name
;
char
delim
=
'('
;
char
delim
=
'('
;
if
(
f
.
params
.
empty
())
impl
->
fs
<<
delim
;
for
(
auto
&&
p
:
f
.
params
)
for
(
auto
&&
p
:
f
.
params
)
{
{
impl
->
fs
<<
delim
<<
p
.
type
<<
" "
<<
p
.
name
;
impl
->
fs
<<
delim
<<
p
.
type
<<
" "
<<
p
.
name
;
...
...
src/driver/main.cpp
View file @
421ecad6
...
@@ -436,11 +436,6 @@ struct compiler
...
@@ -436,11 +436,6 @@ struct compiler
{
"--exhaustive-tune"
},
{
"--exhaustive-tune"
},
ap
.
help
(
"Exhastively search for best tuning parameters for kernels"
),
ap
.
help
(
"Exhastively search for best tuning parameters for kernels"
),
ap
.
set_value
(
true
));
ap
.
set_value
(
true
));
ap
(
co
.
split_single_dyn_dim
,
{
"--split-single-dyn-dim"
},
ap
.
help
(
"If there is a single non-fixed dynamic dimension in the model, then split to "
"static submodules"
),
ap
.
set_value
(
true
));
ap
(
quantize
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
precision
::
fp16
));
ap
(
quantize
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
precision
::
fp16
));
ap
(
quantize
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
precision
::
int8
));
ap
(
quantize
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
precision
::
int8
));
}
}
...
@@ -662,6 +657,26 @@ struct onnx : command<onnx>
...
@@ -662,6 +657,26 @@ struct onnx : command<onnx>
}
}
};
};
struct
tf
:
command
<
tf
>
{
bool
show_ops
=
false
;
void
parse
(
argument_parser
&
ap
)
{
ap
(
show_ops
,
{
"--list"
,
"-l"
},
ap
.
help
(
"List all tf operators supported by MIGraphX"
),
ap
.
set_value
(
true
));
}
void
run
()
const
{
if
(
show_ops
)
{
for
(
const
auto
&
name
:
get_tf_operators
())
std
::
cout
<<
name
<<
std
::
endl
;
}
}
};
struct
main_command
struct
main_command
{
{
static
std
::
string
get_command_help
(
const
std
::
string
&
title
=
colorize
(
color
::
fg_yellow
,
static
std
::
string
get_command_help
(
const
std
::
string
&
title
=
colorize
(
color
::
fg_yellow
,
...
...
src/fuse_pointwise.cpp
View file @
421ecad6
...
@@ -31,6 +31,8 @@
...
@@ -31,6 +31,8 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -67,13 +69,13 @@ static void create_pointwise_modules(module_pass_manager& mpm)
...
@@ -67,13 +69,13 @@ static void create_pointwise_modules(module_pass_manager& mpm)
continue
;
continue
;
if
(
ins
->
get_operator
().
name
()
==
"layout"
)
if
(
ins
->
get_operator
().
name
()
==
"layout"
)
continue
;
continue
;
assert
(
ins
->
get_operator
().
attributes
().
contains
(
"point_op"
));
auto
*
pm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":pointwise"
+
std
::
to_string
(
n
++
));
auto
*
pm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":pointwise"
+
std
::
to_string
(
n
++
));
pm
->
set_bypass
();
pm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
std
::
vector
<
instruction_ref
>
pointwise_inputs
;
std
::
vector
<
instruction_ref
>
pointwise_inputs
;
std
::
size_t
i
=
0
;
std
::
size_t
i
=
0
;
for
(
auto
input
:
ins
->
inputs
())
for
(
auto
input
:
ins
->
inputs
())
{
{
if
(
contains
(
param_map
,
input
))
if
(
contains
(
param_map
,
input
))
...
@@ -92,6 +94,10 @@ static void create_pointwise_modules(module_pass_manager& mpm)
...
@@ -92,6 +94,10 @@ static void create_pointwise_modules(module_pass_manager& mpm)
}
}
}
}
// Don't create pointwise module if no inputs are detected
if
(
pointwise_inputs
.
empty
())
continue
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
ins
->
inputs
().
begin
(),
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
ins
->
inputs
().
end
(),
...
@@ -188,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
...
@@ -188,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
{
{
create_pointwise_modules
(
mpm
);
create_pointwise_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
if
(
enabled
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
{}))
{
return
;
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
...
...
src/include/migraphx/compile_options.hpp
View file @
421ecad6
...
@@ -40,9 +40,6 @@ struct compile_options
...
@@ -40,9 +40,6 @@ struct compile_options
bool
fast_math
=
true
;
bool
fast_math
=
true
;
bool
exhaustive_tune
=
false
;
bool
exhaustive_tune
=
false
;
/// Use the split_single_dyn_dim pass
bool
split_single_dyn_dim
=
false
;
tracer
trace
{};
tracer
trace
{};
};
};
...
...
src/include/migraphx/cpp_generator.hpp
View file @
421ecad6
...
@@ -78,6 +78,7 @@ struct cpp_generator
...
@@ -78,6 +78,7 @@ struct cpp_generator
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
add_generic_param
(
const
std
::
string
&
pname
);
function
&
add_generic_param
(
const
std
::
string
&
pname
);
function
&
unused_param
(
const
std
::
string
&
pname
);
};
};
cpp_generator
();
cpp_generator
();
...
...
src/include/migraphx/op/dequantizelinear.hpp
View file @
421ecad6
...
@@ -37,6 +37,15 @@ namespace op {
...
@@ -37,6 +37,15 @@ namespace op {
struct
dequantizelinear
struct
dequantizelinear
{
{
value
attributes
()
const
{
// Note: point_op attribute is not used in this op. Instead, in
// gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return
{{
"pointwise"
,
true
}};
}
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
...
src/include/migraphx/op/pointwise.hpp
View file @
421ecad6
...
@@ -45,14 +45,15 @@ struct pointwise
...
@@ -45,14 +45,15 @@ struct pointwise
{
{
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
}
}
auto
*
pm
=
mods
.
front
();
auto
*
pm
=
mods
.
front
();
if
(
pm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"pointwise should have only one output."
);
if
(
inputs
.
empty
())
MIGRAPHX_THROW
(
"pointwise should have at least one input"
);
auto
pnames
=
pm
->
get_parameter_names
();
auto
pnames
=
pm
->
get_parameter_names
();
std
::
sort
(
pnames
.
begin
(),
pnames
.
end
());
std
::
sort
(
pnames
.
begin
(),
pnames
.
end
());
check_shapes
{
inputs
,
*
this
}.
has
(
pnames
.
size
()).
same_dims
();
check_shapes
{
inputs
,
*
this
}.
has
(
pnames
.
size
()).
same_dims
();
if
(
pm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"submodule should have only one output."
);
auto
type
=
pm
->
get_output_shapes
().
front
().
type
();
auto
type
=
pm
->
get_output_shapes
().
front
().
type
();
// Scalar output if all inputs are scalar
// Scalar output if all inputs are scalar
...
...
src/include/migraphx/op/quantizelinear.hpp
View file @
421ecad6
...
@@ -38,6 +38,15 @@ namespace op {
...
@@ -38,6 +38,15 @@ namespace op {
struct
quantizelinear
struct
quantizelinear
{
{
std
::
string
name
()
const
{
return
"quantizelinear"
;
}
std
::
string
name
()
const
{
return
"quantizelinear"
;
}
value
attributes
()
const
{
// Note: point_op attribute is not used in this op. Instead, in
// gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return
{{
"pointwise"
,
true
}};
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
same_dims
().
has
(
2
,
3
);
check_shapes
{
inputs
,
*
this
}.
same_dims
().
has
(
2
,
3
);
...
...
src/include/migraphx/op/select_module.hpp
View file @
421ecad6
...
@@ -125,7 +125,7 @@ struct select_module
...
@@ -125,7 +125,7 @@ struct select_module
auto
ps
=
param_shapes
.
at
(
name
);
auto
ps
=
param_shapes
.
at
(
name
);
if
(
a
.
get_shape
()
!=
ps
)
if
(
a
.
get_shape
()
!=
ps
)
{
{
assert
(
ps
.
bytes
()
=
=
a
.
get_shape
().
bytes
());
assert
(
ps
.
bytes
()
<
=
a
.
get_shape
().
bytes
());
return
std
::
make_pair
(
name
,
a
.
reshape
(
ps
));
return
std
::
make_pair
(
name
,
a
.
reshape
(
ps
));
}
}
else
else
...
...
src/include/migraphx/shape.hpp
View file @
421ecad6
...
@@ -222,11 +222,15 @@ struct shape
...
@@ -222,11 +222,15 @@ struct shape
/// Map element index to space index
/// Map element index to space index
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
/// Map element index to multi-dimensional index
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
idx
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// Map element index to multi-dimensional index and put them them into location provided by
/// padding
/// pointers
void
multi_copy
(
std
::
size_t
idx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
/// no padding
bool
packed
()
const
;
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// Returns true is the shape has been transposed. That is the strides are not in descending
...
...
src/include/migraphx/tf.hpp
View file @
421ecad6
...
@@ -43,6 +43,8 @@ struct tf_options
...
@@ -43,6 +43,8 @@ struct tf_options
/// Create a program from a tf pb file (default is nhwc format)
/// Create a program from a tf pb file (default is nhwc format)
program
parse_tf
(
const
std
::
string
&
name
,
const
tf_options
&
options
=
tf_options
{});
program
parse_tf
(
const
std
::
string
&
name
,
const
tf_options
&
options
=
tf_options
{});
std
::
vector
<
std
::
string
>
get_tf_operators
();
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/normalize_attributes.cpp
View file @
421ecad6
...
@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
*
*
* See normalize_attribute.hpp for explaining the options.
* See normalize_attribute.hpp for explaining the options.
*/
*/
template
<
class
Message
>
auto
tune_attribute
(
const
std
::
vector
<
int64_t
>&
vec
,
auto
tune_attribute
(
const
std
::
vector
<
int64_t
>&
vec
,
const
std
::
vector
<
int64_t
>&
axes
,
const
std
::
vector
<
int64_t
>&
axes
,
const
value
&
val
,
const
value
&
val
,
const
std
::
vector
<
std
::
size_t
>&
lens
)
const
std
::
vector
<
std
::
size_t
>&
lens
,
Message
m
)
{
{
std
::
vector
<
int64_t
>
result
(
vec
);
std
::
vector
<
int64_t
>
result
(
vec
);
int64_t
n_rank
=
lens
.
size
();
int64_t
n_rank
=
lens
.
size
();
...
@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
{
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less_equal
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less_equal
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
value out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
value out of range!"
);
}
}
}
}
else
else
{
{
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
value out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
value out of range!"
);
}
}
}
}
}
}
...
@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if
(
not
std
::
equal
(
if
(
not
std
::
equal
(
min_vals
.
begin
(),
min_vals
.
end
(),
result
.
begin
(),
std
::
less_equal
<>
{}))
min_vals
.
begin
(),
min_vals
.
end
(),
result
.
begin
(),
std
::
less_equal
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
attribute out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
attribute out of range!"
);
}
}
}
}
else
else
{
{
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
min_vals
.
begin
(),
std
::
less
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
min_vals
.
begin
(),
std
::
less
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
attribute out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
attribute out of range!"
);
}
}
}
}
}
}
...
@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
...
@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const
auto
&
key
=
rv
.
get_key
();
const
auto
&
key
=
rv
.
get_key
();
if
(
val
.
contains
(
key
))
if
(
val
.
contains
(
key
))
{
{
auto
vv
=
val
.
at
(
key
).
without_key
();
auto
message
=
[
&
]
{
return
op
.
name
()
+
": "
+
key
+
": "
;
};
auto
vv
=
val
.
at
(
key
).
without_key
();
if
(
vv
.
is_array
())
if
(
vv
.
is_array
())
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
...
@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
...
@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes
=
val
.
at
(
"axes"
).
without_key
().
to_vector
<
int64_t
>
();
axes
=
val
.
at
(
"axes"
).
without_key
().
to_vector
<
int64_t
>
();
}
}
auto
vec
=
vv
.
to_vector
<
int64_t
>
();
auto
vec
=
vv
.
to_vector
<
int64_t
>
();
auto
result
=
tune_attribute
(
vec
,
axes
,
rv
.
without_key
(),
lens
);
auto
result
=
tune_attribute
(
vec
,
axes
,
rv
.
without_key
(),
lens
,
message
);
val
[
key
]
=
result
;
val
[
key
]
=
result
;
op
.
from_value
(
val
);
op
.
from_value
(
val
);
val
=
op
.
to_value
();
val
=
op
.
to_value
();
...
@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
...
@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else
else
{
{
auto
num
=
vv
.
to
<
int64_t
>
();
auto
num
=
vv
.
to
<
int64_t
>
();
auto
result
=
tune_attribute
({
num
},
{
num
},
rv
.
without_key
(),
lens
);
auto
result
=
tune_attribute
({
num
},
{
num
},
rv
.
without_key
(),
lens
,
message
);
val
[
key
]
=
result
.
front
();
val
[
key
]
=
result
.
front
();
op
.
from_value
(
val
);
op
.
from_value
(
val
);
val
=
op
.
to_value
();
val
=
op
.
to_value
();
...
...
src/onnx/onnx_parser.cpp
View file @
421ecad6
...
@@ -39,7 +39,19 @@
...
@@ -39,7 +39,19 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_REMOVE_LAST_OUTPUT
);
static
shape
shape_from_dyn_dims
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
shape
::
dynamic_dimension
>&
dyn_dims
)
{
if
(
std
::
all_of
(
dyn_dims
.
begin
(),
dyn_dims
.
end
(),
[](
auto
dd
)
{
return
dd
.
is_fixed
();
}))
{
std
::
vector
<
std
::
size_t
>
dims
;
std
::
transform
(
dyn_dims
.
cbegin
(),
dyn_dims
.
cend
(),
std
::
back_inserter
(
dims
),
[](
auto
d
)
{
return
d
.
max
;
});
return
{
shape_type
,
dims
};
}
return
{
shape_type
,
dyn_dims
};
}
namespace
onnx
{
namespace
onnx
{
...
@@ -302,7 +314,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
...
@@ -302,7 +314,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
else
if
(
map_dyn_input_dims
.
count
(
name
)
>
0
)
else
if
(
map_dyn_input_dims
.
count
(
name
)
>
0
)
{
{
shape
::
type_t
shape_type
=
get_type
(
input
.
type
().
tensor_type
().
elem_type
());
shape
::
type_t
shape_type
=
get_type
(
input
.
type
().
tensor_type
().
elem_type
());
s
=
{
shape_type
,
map_dyn_input_dims
.
at
(
name
)
}
;
s
=
shape_from_dyn_dims
(
shape_type
,
map_dyn_input_dims
.
at
(
name
)
)
;
}
}
else
else
{
{
...
@@ -508,16 +520,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -508,16 +520,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
{
{
return
{
shape_type
};
return
{
shape_type
};
}
}
if
(
std
::
all_of
(
dynamic_dims
.
begin
(),
dynamic_dims
.
end
(),
[](
auto
dd
)
{
return
dd
.
is_fixed
();
}))
return
shape_from_dyn_dims
(
shape_type
,
dynamic_dims
);
{
std
::
vector
<
std
::
size_t
>
dims
;
std
::
transform
(
dynamic_dims
.
begin
(),
dynamic_dims
.
end
(),
std
::
back_inserter
(
dims
),
[](
auto
d
)
{
return
d
.
max
;
});
return
{
shape_type
,
dims
};
}
return
{
shape_type
,
dynamic_dims
};
}
}
shape
::
type_t
get_type
(
int
dtype
)
shape
::
type_t
get_type
(
int
dtype
)
...
...
src/onnx/op_parser.cpp
View file @
421ecad6
...
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
...
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
op_parser_map
().
end
(),
op_parser_map
().
end
(),
std
::
back_inserter
(
result
),
std
::
back_inserter
(
result
),
[
&
](
auto
&&
p
)
{
return
p
.
first
;
});
[
&
](
auto
&&
p
)
{
return
p
.
first
;
});
std
::
sort
(
result
.
begin
(),
result
.
end
());
return
result
;
return
result
;
}
}
...
...
src/pass_manager.cpp
View file @
421ecad6
...
@@ -103,6 +103,7 @@ struct module_pm : module_pass_manager
...
@@ -103,6 +103,7 @@ struct module_pm : module_pass_manager
virtual
void
run_pass
(
const
pass
&
p
)
override
virtual
void
run_pass
(
const
pass
&
p
)
override
{
{
trace
(
"Pass: "
,
p
.
name
());
assert
(
mod
);
assert
(
mod
);
assert
(
mod
->
validate
()
==
mod
->
end
());
assert
(
mod
->
validate
()
==
mod
->
end
());
if
(
enabled
(
MIGRAPHX_TIME_PASSES
{}))
if
(
enabled
(
MIGRAPHX_TIME_PASSES
{}))
...
...
src/propagate_constant.cpp
View file @
421ecad6
...
@@ -27,11 +27,14 @@
...
@@ -27,11 +27,14 @@
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/env.hpp>
#include <unordered_set>
#include <unordered_set>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
)
bool
skip_propogate
(
instruction_ref
ins
)
bool
skip_propogate
(
instruction_ref
ins
)
{
{
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
...
@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const
...
@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const
{
{
if
(
not
literals
[
i
].
empty
())
if
(
not
literals
[
i
].
empty
())
{
{
if
(
enabled
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
{}))
{
std
::
cout
<<
"Constant replace: "
<<
std
::
endl
;
std
::
vector
<
instruction_ref
>
inss
;
fix
([
&
](
auto
self
,
auto
ins
)
{
if
(
contains
(
inss
,
ins
))
return
;
for
(
auto
input
:
ins
->
inputs
())
self
(
input
);
inss
.
push_back
(
ins
);
})(
const_instrs_vec
[
i
]);
m
.
debug_print
(
inss
);
}
assert
(
literals
[
i
].
get_shape
()
==
const_instrs_vec
[
i
]
->
get_shape
());
assert
(
literals
[
i
].
get_shape
()
==
const_instrs_vec
[
i
]
->
get_shape
());
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instrs_vec
[
i
],
l
);
m
.
replace_instruction
(
const_instrs_vec
[
i
],
l
);
...
...
src/py/migraphx_py.cpp
View file @
421ecad6
...
@@ -62,6 +62,7 @@ namespace py = pybind11;
...
@@ -62,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \
PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING
MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace
migraphx
{
namespace
migraphx
{
migraphx
::
value
to_value
(
py
::
kwargs
kwargs
);
migraphx
::
value
to_value
(
py
::
kwargs
kwargs
);
...
@@ -94,6 +95,10 @@ void visit_py(T x, F f)
...
@@ -94,6 +95,10 @@ void visit_py(T x, F f)
{
{
f
(
x
.
template
cast
<
std
::
string
>());
f
(
x
.
template
cast
<
std
::
string
>());
}
}
else
if
(
py
::
isinstance
<
migraphx
::
shape
::
dynamic_dimension
>
(
x
))
{
f
(
migraphx
::
to_value
(
x
.
template
cast
<
migraphx
::
shape
::
dynamic_dimension
>()));
}
else
else
{
{
MIGRAPHX_THROW
(
"VISIT_PY: Unsupported data type!"
);
MIGRAPHX_THROW
(
"VISIT_PY: Unsupported data type!"
);
...
@@ -164,7 +169,10 @@ template <class T>
...
@@ -164,7 +169,10 @@ template <class T>
py
::
buffer_info
to_buffer_info
(
T
&
x
)
py
::
buffer_info
to_buffer_info
(
T
&
x
)
{
{
migraphx
::
shape
s
=
x
.
get_shape
();
migraphx
::
shape
s
=
x
.
get_shape
();
auto
strides
=
s
.
strides
();
assert
(
s
.
type
()
!=
migraphx
::
shape
::
tuple_type
);
if
(
s
.
dynamic
())
MIGRAPHX_THROW
(
"MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info"
);
auto
strides
=
s
.
strides
();
std
::
transform
(
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
{
return
i
*
s
.
type_size
();
});
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
{
return
i
*
s
.
type_size
();
});
py
::
buffer_info
b
;
py
::
buffer_info
b
;
...
@@ -176,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
...
@@ -176,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
b
=
py
::
buffer_info
(
x
.
data
(),
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
as
.
size
(),
py
::
format_descriptor
<
bool
>::
format
(),
py
::
format_descriptor
<
bool
>::
format
(),
s
.
lens
().
size
(),
s
.
ndim
(),
s
.
lens
(),
s
.
lens
(),
strides
);
strides
);
}
}
...
@@ -185,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
...
@@ -185,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
b
=
py
::
buffer_info
(
x
.
data
(),
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
as
.
size
(),
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
s
.
lens
().
size
(),
s
.
ndim
(),
s
.
lens
(),
s
.
lens
(),
strides
);
strides
);
}
}
...
@@ -235,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info)
...
@@ -235,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE
(
migraphx
,
m
)
MIGRAPHX_PYBIND11_MODULE
(
migraphx
,
m
)
{
{
py
::
class_
<
migraphx
::
shape
>
(
m
,
"shape"
)
py
::
class_
<
migraphx
::
shape
>
shape_cls
(
m
,
"shape"
);
shape_cls
.
def
(
py
::
init
([](
py
::
kwargs
kwargs
)
{
.
def
(
py
::
init
([](
py
::
kwargs
kwargs
)
{
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
"float"
));
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
"float"
));
if
(
v
.
contains
(
"dyn_dims"
))
{
auto
dyn_dims
=
migraphx
::
from_value
<
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(
v
.
at
(
"dyn_dims"
));
return
migraphx
::
shape
(
t
,
dyn_dims
);
}
auto
lens
=
v
.
get
<
std
::
size_t
>
(
"lens"
,
{
1
});
auto
lens
=
v
.
get
<
std
::
size_t
>
(
"lens"
,
{
1
});
if
(
v
.
contains
(
"strides"
))
if
(
v
.
contains
(
"strides"
))
return
migraphx
::
shape
(
t
,
lens
,
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
());
return
migraphx
::
shape
(
t
,
lens
,
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
());
...
@@ -248,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -248,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"ndim"
,
&
migraphx
::
shape
::
ndim
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"type_string"
,
&
migraphx
::
shape
::
type_string
)
.
def
(
"type_string"
,
&
migraphx
::
shape
::
type_string
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"dyn_dims"
,
&
migraphx
::
shape
::
dyn_dims
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
.
def
(
"broadcasted"
,
&
migraphx
::
shape
::
broadcasted
)
.
def
(
"broadcasted"
,
&
migraphx
::
shape
::
broadcasted
)
.
def
(
"standard"
,
&
migraphx
::
shape
::
standard
)
.
def
(
"standard"
,
&
migraphx
::
shape
::
standard
)
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
)
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
)
.
def
(
"dynamic"
,
&
migraphx
::
shape
::
dynamic
)
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
py
::
enum_
<
migraphx
::
shape
::
type_t
>
(
shape_cls
,
"type_t"
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM
);
py
::
class_
<
migraphx
::
shape
::
dynamic_dimension
>
(
shape_cls
,
"dynamic_dimension"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
std
::
size_t
,
std
::
size_t
>
())
.
def
(
py
::
init
<
std
::
size_t
,
std
::
size_t
,
std
::
set
<
std
::
size_t
>>
())
.
def_readwrite
(
"min"
,
&
migraphx
::
shape
::
dynamic_dimension
::
min
)
.
def_readwrite
(
"max"
,
&
migraphx
::
shape
::
dynamic_dimension
::
max
)
.
def_readwrite
(
"optimals"
,
&
migraphx
::
shape
::
dynamic_dimension
::
optimals
)
.
def
(
"is_fixed"
,
&
migraphx
::
shape
::
dynamic_dimension
::
is_fixed
);
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
py
::
init
([](
py
::
buffer
b
)
{
.
def
(
py
::
init
([](
py
::
buffer
b
)
{
...
@@ -282,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -282,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
class_
<
migraphx
::
target
>
(
m
,
"target"
);
py
::
class_
<
migraphx
::
target
>
(
m
,
"target"
);
py
::
class_
<
migraphx
::
instruction_ref
>
(
m
,
"instruction_ref"
);
py
::
class_
<
migraphx
::
instruction_ref
>
(
m
,
"instruction_ref"
)
.
def
(
"shape"
,
[](
migraphx
::
instruction_ref
i
)
{
return
i
->
get_shape
();
})
.
def
(
"op"
,
[](
migraphx
::
instruction_ref
i
)
{
return
i
->
get_operator
();
});
py
::
class_
<
migraphx
::
module
,
std
::
unique_ptr
<
migraphx
::
module
,
py
::
nodelete
>>
(
m
,
"module"
)
py
::
class_
<
migraphx
::
module
,
std
::
unique_ptr
<
migraphx
::
module
,
py
::
nodelete
>>
(
m
,
"module"
)
.
def
(
"print"
,
[](
const
migraphx
::
module
&
mm
)
{
std
::
cout
<<
mm
<<
std
::
endl
;
})
.
def
(
"print"
,
[](
const
migraphx
::
module
&
mm
)
{
std
::
cout
<<
mm
<<
std
::
endl
;
})
...
@@ -433,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -433,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx"
,
"parse_onnx"
,
[](
const
std
::
string
&
filename
,
[](
const
std
::
string
&
filename
,
unsigned
int
default_dim_value
,
unsigned
int
default_dim_value
,
migraphx
::
shape
::
dynamic_dimension
default_dyn_dim_value
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
skip_unknown_operators
,
bool
print_program_on_error
,
bool
print_program_on_error
,
int64_t
max_loop_iterations
)
{
int64_t
max_loop_iterations
)
{
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_dyn_input_dims
=
map_dyn_input_dims
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
max_loop_iterations
=
max_loop_iterations
;
options
.
max_loop_iterations
=
max_loop_iterations
;
...
@@ -447,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -447,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
},
"Parse onnx file"
,
"Parse onnx file"
,
py
::
arg
(
"filename"
),
py
::
arg
(
"filename"
),
py
::
arg
(
"default_dim_value"
)
=
1
,
py
::
arg
(
"default_dim_value"
)
=
0
,
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"default_dyn_dim_value"
)
=
migraphx
::
shape
::
dynamic_dimension
{
1
,
1
},
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_dyn_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"max_loop_iterations"
)
=
10
);
py
::
arg
(
"max_loop_iterations"
)
=
10
);
...
@@ -457,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -457,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx_buffer"
,
"parse_onnx_buffer"
,
[](
const
std
::
string
&
onnx_buffer
,
[](
const
std
::
string
&
onnx_buffer
,
unsigned
int
default_dim_value
,
unsigned
int
default_dim_value
,
migraphx
::
shape
::
dynamic_dimension
default_dyn_dim_value
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
skip_unknown_operators
,
bool
print_program_on_error
)
{
bool
print_program_on_error
)
{
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_dyn_input_dims
=
map_dyn_input_dims
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
print_program_on_error
=
print_program_on_error
;
return
migraphx
::
parse_onnx_buffer
(
onnx_buffer
,
options
);
return
migraphx
::
parse_onnx_buffer
(
onnx_buffer
,
options
);
},
},
"Parse onnx file"
,
"Parse onnx file"
,
py
::
arg
(
"filename"
),
py
::
arg
(
"filename"
),
py
::
arg
(
"default_dim_value"
)
=
1
,
py
::
arg
(
"default_dim_value"
)
=
0
,
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"default_dyn_dim_value"
)
=
migraphx
::
shape
::
dynamic_dimension
{
1
,
1
},
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_dyn_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
);
py
::
arg
(
"print_program_on_error"
)
=
false
);
...
...
src/shape.cpp
View file @
421ecad6
...
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
...
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
}
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
)
const
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
dx
)
const
{
{
assert
(
this
->
standard
());
assert
(
idx
<
elements
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
multi_copy
(
i
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
multi_copy
(
idx
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
return
indices
;
return
indices
;
}
}
void
shape
::
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
void
shape
::
multi_copy
(
std
::
size_t
i
dx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
{
{
assert
(
this
->
standard
())
;
size_t
tidx
=
idx
;
(
void
)
end
;
(
void
)
end
;
assert
(
idx
<
elements
());
assert
(
lens
().
size
()
<=
(
end
-
start
));
assert
(
lens
().
size
()
<=
(
end
-
start
));
std
::
transform
(
strides
().
begin
(),
for
(
size_t
ii
=
lens
().
size
()
-
1
;
ii
>
0
;
ii
--
)
strides
().
end
(),
{
lens
().
begin
(),
*
(
start
+
ii
)
=
tidx
%
lens
()[
ii
];
start
,
tidx
=
tidx
/
lens
()[
ii
];
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
}
assert
(
len
>
0
and
stride
>
0
);
*
start
=
tidx
;
return
(
i
/
stride
)
%
len
;
});
}
}
bool
shape
::
packed
()
const
bool
shape
::
packed
()
const
...
@@ -709,14 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
...
@@ -709,14 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
{
{
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
migraphx
::
value
x
)
{
std
::
transform
(
auto
x_min
=
x
.
at
(
"min"
).
template
to
<
size_t
>();
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
const
migraphx
::
value
&
x
)
{
auto
x_max
=
x
.
at
(
"max"
).
template
to
<
size_t
>();
return
from_value
<
shape
::
dynamic_dimension
>
(
x
);
auto
v_optimals
=
x
.
at
(
"optimals"
);
});
std
::
set
<
size_t
>
set_x_optimals
=
from_value
<
std
::
set
<
std
::
size_t
>>
(
x
.
at
(
"optimals"
));
return
shape
::
dynamic_dimension
{
x_min
,
x_max
,
set_x_optimals
};
});
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
}
}
...
...
src/simplify_algebra.cpp
View file @
421ecad6
...
@@ -204,6 +204,131 @@ struct find_mul_slice_conv
...
@@ -204,6 +204,131 @@ struct find_mul_slice_conv
}
}
};
};
struct
find_mul_dot
{
auto
matcher
()
const
{
auto
is_dot_const_inputs
=
match
::
name
(
"dot"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
is_constant
()));
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
is_dot_const_inputs
.
bind
(
"dot"
),
match
::
name
(
"broadcast"
,
"multibroadcast"
).
bind
(
"c"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
dot_ins
=
r
.
instructions
[
"dot"
];
auto
a_ins
=
dot_ins
->
inputs
()[
0
];
auto
b_ins
=
dot_ins
->
inputs
()[
1
];
auto
c_ins
=
r
.
instructions
[
"c"
];
const
auto
&
c_strides
=
c_ins
->
get_shape
().
strides
();
// There should only be one stride that is not zero
if
(
std
::
count_if
(
c_strides
.
begin
(),
c_strides
.
end
(),
[](
auto
s
)
{
return
s
!=
0
;
})
>
1
)
return
;
auto
add_mul_const
=
[
&
](
instruction_ref
x_ins
)
{
if
(
not
x_ins
->
can_eval
())
return
m
.
end
();
auto
broadcast_v
=
c_ins
->
get_operator
().
to_value
();
broadcast_v
[
"out_lens"
]
=
x_ins
->
get_shape
().
lens
();
auto
cb_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
c_ins
->
name
(),
broadcast_v
),
c_ins
->
inputs
());
return
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
x_ins
,
cb_ins
);
};
if
(
c_strides
.
back
()
==
1
)
{
b_ins
=
add_mul_const
(
b_ins
);
}
else
if
(
c_strides
[
c_strides
.
size
()
-
2
]
==
1
)
{
a_ins
=
add_mul_const
(
a_ins
);
}
else
if
(
c_ins
->
get_shape
().
scalar
())
{
if
(
a_ins
->
can_eval
())
a_ins
=
add_mul_const
(
a_ins
);
else
b_ins
=
add_mul_const
(
b_ins
);
}
else
{
return
;
}
if
(
contains
({
a_ins
,
b_ins
},
m
.
end
()))
return
;
m
.
replace_instruction
(
ins
,
make_op
(
"dot"
),
a_ins
,
b_ins
);
}
};
struct
find_dot_mul
{
auto
matcher
()
const
{
auto
const_broadcast
=
match
::
name
(
"broadcast"
,
"multibroadcast"
)(
match
::
is_constant
());
auto
mul
=
match
::
name
(
"mul"
)(
match
::
used_once
(),
match
::
either_arg
(
0
,
1
)(
const_broadcast
.
bind
(
"d"
),
match
::
none_of
(
match
::
is_constant
()).
bind
(
"z"
)));
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
mul
,
match
::
is_constant
().
bind
(
"c"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
ins
->
inputs
()[
0
];
auto
b_ins
=
ins
->
inputs
()[
1
];
auto
d_ins
=
r
.
instructions
[
"d"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
z_ins
=
r
.
instructions
[
"z"
];
const
auto
&
d_strides
=
d_ins
->
get_shape
().
strides
();
// There should only be one stride that is not zero
if
(
std
::
count_if
(
d_strides
.
begin
(),
d_strides
.
end
(),
[](
auto
s
)
{
return
s
!=
0
;
})
>
1
)
return
;
if
(
not
d_ins
->
get_shape
().
scalar
())
{
if
(
d_strides
.
back
()
==
1
and
not
b_ins
->
can_eval
())
return
;
if
(
d_strides
[
d_strides
.
size
()
-
2
]
==
1
and
not
a_ins
->
can_eval
())
return
;
}
auto
broadcast_v
=
d_ins
->
get_operator
().
to_value
();
auto
c_lens
=
c_ins
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
permutation
(
c_lens
.
size
());
std
::
iota
(
permutation
.
begin
(),
permutation
.
end
(),
0
);
std
::
swap
(
permutation
.
back
(),
permutation
[
permutation
.
size
()
-
2
]);
c_lens
=
reorder_dims
(
c_lens
,
permutation
);
broadcast_v
[
"out_lens"
]
=
c_lens
;
auto
db_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
d_ins
->
name
(),
broadcast_v
),
d_ins
->
inputs
());
auto
db_transpose_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
permutation
}}),
db_ins
);
auto
cd_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
c_ins
,
db_transpose_ins
);
if
(
c_ins
==
b_ins
)
{
a_ins
=
z_ins
;
b_ins
=
cd_ins
;
}
else
{
a_ins
=
cd_ins
;
b_ins
=
z_ins
;
}
m
.
replace_instruction
(
ins
,
make_op
(
"dot"
),
a_ins
,
b_ins
);
}
};
// ******************************
// ******************************
// a * (x + b) => a * x + a * b
// a * (x + b) => a * x + a * b
// ******************************
// ******************************
...
@@ -361,30 +486,118 @@ struct find_inner_broadcast
...
@@ -361,30 +486,118 @@ struct find_inner_broadcast
{
{
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
static
auto
non_scalar_op
(
const
std
::
string
&
name
)
{
return
[
=
](
instruction_ref
ins
)
{
if
(
ins
->
get_shape
().
scalar
())
return
false
;
return
ins
->
name
()
==
name
;
};
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
broadcasts
=
ins
->
inputs
();
auto
broadcasts
=
ins
->
inputs
();
if
(
broadcasts
.
empty
())
if
(
broadcasts
.
empty
())
return
;
return
;
bool
mixed_broadcasts
=
any_of
(
broadcasts
,
non_scalar_op
(
"broadcast"
))
and
any_of
(
broadcasts
,
non_scalar_op
(
"multibroadcast"
));
// If the broadcast is not a single dimension, then dont perform inner_broadcast
if
(
mixed_broadcasts
and
any_of
(
broadcasts
,
[
&
](
instruction_ref
i
)
{
if
(
i
->
get_shape
().
scalar
())
return
false
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
false
;
auto
input
=
i
->
inputs
().
at
(
0
);
const
auto
&
lens
=
input
->
get_shape
().
lens
();
return
std
::
count_if
(
lens
.
begin
(),
lens
.
end
(),
[
&
](
std
::
size_t
d
)
{
return
d
==
1
;
})
<
(
lens
.
size
()
-
1
);
}))
return
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
broadcasts
.
end
(),
std
::
back_inserter
(
inputs
),
std
::
back_inserter
(
inputs
),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
[
&
](
instruction_ref
i
)
{
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
auto
input
=
i
->
inputs
().
front
();
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
()
and
if
(
mixed_broadcasts
and
not
i
->
get_shape
().
scalar
()
and
i
->
get_shape
().
elements
()
!=
1
;
i
->
get_shape
().
lens
().
size
()
>
1
)
}))
return
m
.
insert_instruction
(
i
,
make_op
(
"squeeze"
),
input
);
return
;
return
input
;
});
auto
b_it
=
std
::
find_if
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
[
&
](
auto
i
)
{
return
not
i
->
get_shape
().
scalar
();
std
::
sort
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
by
(
std
::
less
<>
{},
[](
instruction_ref
i
)
{
});
if
(
i
->
get_shape
().
scalar
())
if
(
b_it
==
broadcasts
.
end
())
return
2
;
b_it
=
broadcasts
.
begin
();
else
if
(
i
->
name
()
==
"broadcast"
)
return
0
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
1
;
return
3
;
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
m
.
replace_instruction
(
ins
,
(
*
b_it
)
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
};
struct
find_dot_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a
=
ins
->
inputs
()[
0
];
auto
b
=
ins
->
inputs
()[
1
];
if
(
a
->
get_operator
().
name
()
!=
b
->
get_operator
().
name
())
return
;
if
(
ins
->
get_shape
().
lens
().
size
()
<
3
)
return
;
auto
nbatch_axes
=
ins
->
get_shape
().
lens
().
size
()
-
2
;
const
auto
&
a_strides
=
a
->
get_shape
().
strides
();
const
auto
&
b_strides
=
b
->
get_shape
().
strides
();
// Find leading batch axes that are broadcasted
auto
p
=
std
::
mismatch
(
a_strides
.
begin
(),
a_strides
.
begin
()
+
nbatch_axes
,
b_strides
.
begin
(),
b_strides
.
begin
()
+
nbatch_axes
,
[](
auto
astride
,
auto
bstride
)
{
return
astride
==
0
and
bstride
==
0
;
});
auto
naxes
=
p
.
first
-
a_strides
.
begin
();
assert
(
naxes
<=
nbatch_axes
);
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
auto
insert_broadcast
=
[
&
](
instruction_ref
b_ins
)
->
instruction_ref
{
auto
input
=
b_ins
->
inputs
()[
0
];
std
::
vector
<
std
::
size_t
>
lens
(
b_ins
->
get_shape
().
lens
().
begin
()
+
naxes
,
b_ins
->
get_shape
().
lens
().
end
());
if
(
b_ins
->
name
()
==
"multibroadcast"
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
input
);
}
else
if
(
b_ins
->
name
()
==
"broadcast"
)
{
auto
v
=
b_ins
->
get_operator
().
to_value
();
auto
axis
=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
()
-
naxes
;
return
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
lens
}}),
input
);
}
assert
(
false
);
return
m
.
end
();
};
auto
a1
=
insert_broadcast
(
a
);
auto
b1
=
insert_broadcast
(
b
);
auto
dot
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
a1
,
b1
);
auto
broadcast
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins
->
get_shape
().
lens
()}}),
dot
);
m
.
replace_instruction
(
ins
,
broadcast
);
}
}
};
};
...
@@ -393,7 +606,8 @@ struct find_concat_op
...
@@ -393,7 +606,8 @@ struct find_concat_op
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
)),
match
::
used_once
()));
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)),
match
::
used_once
()));
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
...
@@ -412,7 +626,8 @@ struct find_concat_op
...
@@ -412,7 +626,8 @@ struct find_concat_op
static
bool
is_valid_op
(
const
operation
&
op
)
static
bool
is_valid_op
(
const
operation
&
op
)
{
{
return
op
.
name
()
==
"broadcast"
or
op
.
attributes
().
contains
(
"pointwise"
);
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
())
or
op
.
attributes
().
contains
(
"pointwise"
);
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
...
@@ -440,6 +655,16 @@ struct find_concat_op
...
@@ -440,6 +655,16 @@ struct find_concat_op
op
=
b
;
op
=
b
;
iaxis
=
0
;
iaxis
=
0
;
}
}
else
if
(
op
.
name
()
==
"multibroadcast"
)
{
shape
bshape
=
(
*
start
)
->
get_shape
();
auto
input
=
(
*
start
)
->
inputs
()[
0
];
if
(
iaxis
>=
bshape
.
strides
().
size
()
or
bshape
.
strides
()[
iaxis
]
==
0
)
return
{
start
,
last
};
op
.
from_value
({{
"out_lens"
,
get_output_lens
(
start
,
last
,
iaxis
)}});
auto
delta
=
bshape
.
lens
().
size
()
-
input
->
get_shape
().
lens
().
size
();
iaxis
-=
delta
;
}
std
::
vector
<
instruction_ref
>
concats
;
std
::
vector
<
instruction_ref
>
concats
;
for
(
std
::
size_t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
...
@@ -1260,12 +1485,15 @@ void simplify_algebra::apply(module& m) const
...
@@ -1260,12 +1485,15 @@ void simplify_algebra::apply(module& m) const
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_inner_broadcast
{},
find_inner_broadcast
{},
find_dot_broadcast
{},
find_double_add_lit_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_convs
{},
find_add_convs
{},
find_conv_dot_horiz_fusion
{},
find_conv_dot_horiz_fusion
{},
find_mul_conv
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_slice_conv
{},
find_mul_dot
{},
find_dot_mul
{},
find_mul_add
{},
find_mul_add
{},
find_unit_ops
{},
find_unit_ops
{},
find_neg_unit_ops
{},
find_neg_unit_ops
{},
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment