Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fe493c28
Commit
fe493c28
authored
Apr 10, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into ck-gsg
parents
ba0b3794
cce35871
Changes
121
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
513 additions
and
49 deletions
+513
-49
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+350
-0
src/include/migraphx/check_context.hpp
src/include/migraphx/check_context.hpp
+23
-1
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+5
-0
src/include/migraphx/compile_options.hpp
src/include/migraphx/compile_options.hpp
+9
-1
src/include/migraphx/cpp_generator.hpp
src/include/migraphx/cpp_generator.hpp
+5
-0
src/include/migraphx/dynamic_loader.hpp
src/include/migraphx/dynamic_loader.hpp
+6
-0
src/include/migraphx/fuse_reduce.hpp
src/include/migraphx/fuse_reduce.hpp
+43
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+21
-1
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+2
-0
src/include/migraphx/msgpack.hpp
src/include/migraphx/msgpack.hpp
+2
-0
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+1
-1
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+1
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+16
-7
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+5
-1
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+4
-8
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+1
-1
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+2
-2
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+15
-23
No files found.
src/fuse_reduce.cpp
0 → 100644
View file @
fe493c28
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <iterator>
#include <map>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
fused_reduce
{
std
::
vector
<
std
::
int64_t
>
axes
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
auto
*
sm
=
mods
.
front
();
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
auto
names
=
sm
->
get_parameter_names
();
check_shapes
{
inputs
,
*
this
}.
has
(
names
.
size
()).
same_ndims
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
auto
shapes
=
sm
->
get_parameter_shapes
();
// Check dimension matches for each input
if
(
not
equal
(
names
,
inputs
,
[
&
](
const
auto
&
name
,
const
auto
&
input
)
{
return
shapes
.
at
(
name
).
lens
()
==
input
.
lens
();
}))
MIGRAPHX_THROW
(
"Dimenstion does not match the submodule."
);
const
auto
&
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
if
(
lens
!=
sm
->
get_output_shapes
().
front
().
lens
())
{
for
(
const
auto
&
axis
:
axes
)
{
lens
[
axis
]
=
1
;
}
}
return
shape
::
from_permutation
(
sm
->
get_output_shapes
().
front
().
type
(),
lens
,
find_permutation
(
inputs
));
}
std
::
string
name
()
const
{
return
"fused_reduce"
;
}
};
MIGRAPHX_REGISTER_OP
(
fused_reduce
);
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
get_ins_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
auto
names
=
sm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
assert
(
names
.
size
()
==
inputs
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
inputs
.
begin
(),
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
});
return
result
;
}
static
void
insert_params
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
auto
n
=
sm
->
get_parameter_shapes
().
size
();
for
(
auto
input
:
ins
->
inputs
())
{
if
(
contains
(
map_ins
,
input
))
continue
;
auto
s
=
shape
{
input
->
get_shape
().
type
(),
input
->
get_shape
().
lens
()};
map_ins
[
input
]
=
sm
->
add_parameter
(
"x"
+
std
::
to_string
(
n
++
),
s
);
}
}
static
auto
insert_ins_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
insert_params
(
sm
,
ins
,
map_ins
);
return
sm
->
add_instructions
({
ins
},
map_ins
);
}
static
auto
insert_ins_in_submodule
(
module_ref
sm
,
instruction_ref
ins
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
return
insert_ins_in_submodule
(
sm
,
ins
,
map_ins
);
}
static
auto
insert_module_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
insert_params
(
sm
,
ins
,
map_ins
);
auto
*
m
=
ins
->
module_inputs
().
front
();
auto
param_map
=
get_ins_param_map
(
ins
->
inputs
(),
m
);
for
(
auto
&&
[
input
,
param
]
:
param_map
)
{
map_ins
[
param
]
=
map_ins
.
at
(
input
);
}
return
sm
->
add_instructions
(
m
,
map_ins
);
}
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
std
::
vector
<
instruction_ref
>
result
;
std
::
map
<
std
::
string
,
instruction_ref
>
names
;
for
(
auto
&&
[
input
,
param
]
:
map_ins
)
{
if
(
not
sm
->
has_instruction
(
param
))
continue
;
if
(
param
->
name
()
!=
"@param"
)
continue
;
if
(
not
parent
.
has_instruction
(
input
))
continue
;
auto
v
=
param
->
get_operator
().
to_value
();
auto
name
=
v
.
at
(
"parameter"
).
to
<
std
::
string
>
();
names
[
name
]
=
input
;
}
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
result
),
[](
const
auto
&
p
)
{
return
p
.
second
;
});
assert
(
result
.
size
()
==
sm
->
get_parameter_shapes
().
size
());
return
result
;
}
static
void
create_reduce_modules
(
module_pass_manager
&
mpm
)
{
std
::
size_t
n
=
0
;
for
(
auto
ins
:
iterator_for
(
mpm
.
get_module
()))
{
if
(
not
ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
continue
;
if
(
ins
->
inputs
().
size
()
!=
1
)
continue
;
auto
*
rm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":"
+
ins
->
name
()
+
std
::
to_string
(
n
++
));
rm
->
set_bypass
();
rm
->
add_return
(
insert_ins_in_submodule
(
rm
,
ins
));
auto
v
=
ins
->
get_operator
().
to_value
();
mpm
.
get_module
().
replace_instruction
(
ins
,
make_op
(
"fused_reduce"
,
{{
"axes"
,
v
[
"axes"
]}}),
ins
->
inputs
(),
{
rm
});
}
}
template
<
class
...
Ms
>
static
auto
match_broadcast
(
Ms
...
ms
)
{
return
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"multibroadcast"
)(
match
::
arg
(
0
)(
ms
...),
match
::
used_once
()).
bind
(
"broadcast"
));
}
template
<
class
...
Ms
>
static
auto
any_input
(
Ms
...
ms
)
{
return
match
::
any_of
[
match
::
inputs
()](
match
::
any
(
ms
...).
bind
(
"input"
));
}
static
auto
match_broadcastable_input
(
const
std
::
string
&
op
,
const
std
::
string
&
name
)
{
auto
match_op
=
match
::
name
(
op
)(
match
::
used_once
()).
bind
(
name
);
auto
match_op_input
=
any_input
(
match_op
,
match
::
used_once
());
auto
broadcast_match_op_input
=
any_input
(
match_broadcast
(
match_op
),
match
::
used_once
());
return
match
::
any_of
(
match_op_input
,
broadcast_match_op_input
);
}
namespace
{
struct
find_pointwise_reduce
{
auto
matcher
()
const
{
return
match
::
name
(
"fused_reduce"
)(
match_broadcastable_input
(
"pointwise"
,
"pointwise"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
reduce
=
r
.
result
;
auto
input
=
r
.
instructions
[
"pointwise"
];
const
auto
*
pm
=
input
->
module_inputs
().
front
();
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
pm
->
name
()
+
":"
+
old_rm
->
name
());
rm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Insert pointwise
auto
rins
=
insert_ins_in_submodule
(
rm
,
input
,
map_ins
).
front
();
map_ins
[
input
]
=
rins
;
if
(
contains
(
r
.
instructions
,
"broadcast"
))
{
auto
broadcast
=
r
.
instructions
[
"broadcast"
];
map_ins
[
broadcast
]
=
insert_ins_in_submodule
(
rm
,
broadcast
,
map_ins
).
front
();
}
// Insert fused_reduce
rm
->
add_return
(
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
));
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
reduce
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
struct
find_reduce_pointwise
{
auto
matcher
()
const
{
return
match
::
name
(
"pointwise"
)(
match_broadcastable_input
(
"fused_reduce"
,
"reduce"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
pw
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
input
=
r
.
instructions
[
"input"
];
const
auto
*
pm
=
pw
->
module_inputs
().
front
();
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":"
+
pm
->
name
());
rm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Copy module instructions
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
if
(
contains
(
r
.
instructions
,
"broadcast"
))
{
auto
broadcast
=
r
.
instructions
[
"broadcast"
];
map_ins
[
broadcast
->
inputs
().
front
()]
=
rm
->
get_returns
().
front
();
auto
bout
=
insert_ins_in_submodule
(
rm
,
broadcast
,
map_ins
);
map_ins
[
input
]
=
bout
.
front
();
}
else
{
map_ins
[
input
]
=
rm
->
get_returns
().
front
();
}
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
struct
find_reduce_reduce
{
auto
matcher
()
const
{
return
match
::
name
(
"fused_reduce"
)(
match_broadcastable_input
(
"fused_reduce"
,
"reduce"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
reduce1
=
r
.
result
;
auto
reduce2
=
r
.
instructions
[
"reduce"
];
auto
input
=
r
.
instructions
[
"input"
];
if
(
reduce1
->
get_operator
()
!=
reduce2
->
get_operator
())
return
;
const
auto
*
rm1
=
reduce1
->
module_inputs
().
front
();
const
auto
*
rm2
=
reduce2
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
rm1
->
name
()
+
":"
+
rm2
->
name
());
rm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Copy reduce1 instructions
insert_module_in_submodule
(
rm
,
reduce2
,
map_ins
);
if
(
contains
(
r
.
instructions
,
"broadcast"
))
{
auto
broadcast
=
r
.
instructions
[
"broadcast"
];
map_ins
[
broadcast
->
inputs
().
front
()]
=
rm
->
get_returns
().
front
();
auto
bout
=
insert_ins_in_submodule
(
rm
,
broadcast
,
map_ins
);
map_ins
[
input
]
=
bout
.
front
();
}
else
{
map_ins
[
input
]
=
rm
->
get_returns
().
front
();
}
auto
out
=
insert_module_in_submodule
(
rm
,
reduce1
,
map_ins
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
reduce1
,
reduce1
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
}
// namespace
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
create_reduce_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{},
find_pointwise_reduce
{},
find_reduce_reduce
{});
mpm
.
run_pass
(
dead_code_elimination
{});
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/include/migraphx/check_context.hpp
View file @
fe493c28
...
...
@@ -27,6 +27,8 @@
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -36,7 +38,27 @@ struct check_context
{
struct
op
:
auto_register_op
<
op
>
{
std
::
string
name
()
const
{
return
"check_context::"
+
get_type_name
<
T
>
();
}
static
std
::
string
compute_op_name
()
{
const
auto
&
op_type_name
=
get_type_name
<
T
>
();
const
auto
&
split_name
=
split_string
(
op_type_name
,
':'
);
std
::
vector
<
std
::
string
>
name_without_version
=
{
"check_context"
};
// op_type_name would contain internal namespace name with version_x_y_z
// remove version and construct op_name such as check_context::migraphx::gpu::context
std
::
copy_if
(
split_name
.
begin
(),
split_name
.
end
(),
std
::
back_inserter
(
name_without_version
),
[
&
](
const
auto
&
i
)
{
return
not
i
.
empty
()
and
not
contains
(
i
,
"version"
);
});
return
join_strings
(
name_without_version
,
"::"
);
}
std
::
string
name
()
const
{
static
auto
op_name
=
compute_op_name
();
return
op_name
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
{};
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
...
...
src/include/migraphx/common.hpp
View file @
fe493c28
...
...
@@ -41,6 +41,11 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
std
::
vector
<
instruction_ref
>
insert_common_args
(
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
);
std
::
vector
<
instruction_ref
>
add_common_args
(
module
&
m
,
std
::
vector
<
instruction_ref
>
inputs
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
,
...
...
src/include/migraphx/compile_options.hpp
View file @
fe493c28
...
...
@@ -32,9 +32,17 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
compile_options
{
bool
offload_copy
=
false
;
/**
* Have MIGX allocate memory for parameters and add instructions
* to copy parameters and output to/from an offload device like a GPU.
*/
bool
offload_copy
=
false
;
bool
fast_math
=
true
;
bool
exhaustive_tune
=
false
;
/// Use the split_single_dyn_dim pass
bool
split_single_dyn_dim
=
false
;
tracer
trace
{};
};
...
...
src/include/migraphx/cpp_generator.hpp
View file @
fe493c28
...
...
@@ -77,6 +77,7 @@ struct cpp_generator
function
&
set_types
(
const
module
&
m
);
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
add_generic_param
(
const
std
::
string
&
pname
);
};
cpp_generator
();
...
...
@@ -105,6 +106,10 @@ struct cpp_generator
std
::
string
create_function
(
const
function
&
f
);
static
std
::
vector
<
std
::
string
>
to_args
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
std
::
unordered_map
<
instruction_ref
,
std
::
string
>&
names
);
private:
std
::
unique_ptr
<
cpp_generator_impl
>
impl
;
};
...
...
src/include/migraphx/dynamic_loader.hpp
View file @
fe493c28
...
...
@@ -37,6 +37,12 @@ struct dynamic_loader_impl;
struct
dynamic_loader
{
template
<
class
T
>
static
fs
::
path
path
(
T
*
address
)
{
return
path
(
reinterpret_cast
<
void
*>
(
address
));
}
static
fs
::
path
path
(
void
*
address
);
dynamic_loader
()
=
default
;
dynamic_loader
(
const
fs
::
path
&
p
);
...
...
src/include/migraphx/fuse_reduce.hpp
0 → 100644
View file @
fe493c28
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#include <migraphx/config.hpp>
#include <string>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module_pass_manager
;
struct
fuse_reduce
{
std
::
string
name
()
const
{
return
"fuse_reduce"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
src/include/migraphx/matcher.hpp
View file @
fe493c28
...
...
@@ -347,6 +347,7 @@ match::matcher_result find_match(module& modl, M&& m)
}
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_VALIDATE_MATCHES
)
/// Find matches for an instruction in the module
template
<
class
Mod
,
class
...
Ms
>
...
...
@@ -356,7 +357,11 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
const
#endif
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
bool
match
=
false
;
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
if
(
match
)
...
...
@@ -371,7 +376,20 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
ins
);
}
// If its already invalid dont validate it again
bool
invalidated
=
validate
and
get_module
(
mod
).
validate
()
!=
get_module
(
mod
).
end
();
m
.
apply
(
mod
,
r
);
if
(
validate
and
not
invalidated
)
{
auto
invalid
=
get_module
(
mod
).
validate
();
if
(
invalid
!=
get_module
(
mod
).
end
())
{
std
::
cout
<<
"Invalid program from match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Invalid instructions: "
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
invalid
->
inputs
());
get_module
(
mod
).
debug_print
(
invalid
);
}
}
match
=
true
;
},
ms
...);
...
...
@@ -520,6 +538,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
return
not
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
dynamic_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
static_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
broadcasted
();
...
...
src/include/migraphx/module.hpp
View file @
fe493c28
...
...
@@ -178,6 +178,8 @@ struct module
bool
has_instruction
(
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
get_returns
()
const
;
std
::
size_t
size
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
end
()
const
;
...
...
src/include/migraphx/msgpack.hpp
View file @
fe493c28
...
...
@@ -26,10 +26,12 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <functional>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
to_msgpack
(
const
value
&
v
,
std
::
function
<
void
(
const
char
*
,
std
::
size_t
)
>
writer
);
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
);
value
from_msgpack
(
const
std
::
vector
<
char
>&
buffer
);
value
from_msgpack
(
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/onnx.hpp
View file @
fe493c28
...
...
@@ -37,7 +37,7 @@ struct onnx_options
std
::
size_t
default_dim_value
=
0
;
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
};
/// Explicitly specify the dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
...
...
src/include/migraphx/op/argmax.hpp
View file @
fe493c28
...
...
@@ -62,7 +62,7 @@ struct argmax
if
(
s0
.
dynamic
())
{
auto
dyn_dims
=
s0
.
dyn_dims
();
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
dyn_dims
[
axis
]
=
{
1
,
1
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
else
...
...
src/include/migraphx/op/concat.hpp
View file @
fe493c28
...
...
@@ -134,7 +134,7 @@ struct concat
}
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
,
0
};
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
};
return
{
inputs
[
0
].
type
(),
new_dims
};
}
else
...
...
src/include/migraphx/op/contiguous.hpp
View file @
fe493c28
...
...
@@ -48,7 +48,7 @@ struct contiguous
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s0
=
inputs
.
front
();
if
(
s0
.
dynamic
()
or
s0
.
standard
()
)
if
(
s0
.
dynamic
())
{
return
s0
;
}
...
...
src/include/migraphx/op/convolution.hpp
View file @
fe493c28
...
...
@@ -38,6 +38,10 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/**
* Convolution operator. Does not support optimal dimensions for spatial dimensions. Returns empty
* optimals.
*/
struct
convolution
{
std
::
vector
<
std
::
size_t
>
padding
=
{
0
,
0
};
...
...
@@ -148,7 +152,7 @@ struct convolution
else
{
auto
l
=
input_shape
.
lens
().
at
(
0
);
output_dyn_dims
.
push_back
({
l
,
l
,
0
});
output_dyn_dims
.
push_back
({
l
,
l
});
}
};
...
...
@@ -165,25 +169,30 @@ struct convolution
if
(
x_shape
.
dynamic
())
{
auto
x
=
x_shape
.
dyn_dims
()[
i
+
2
];
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
ceil_div
(
x
.
opt
,
s
)});
std
::
set
<
std
::
size_t
>
optimals
{};
std
::
transform
(
x
.
optimals
.
begin
(),
x
.
optimals
.
end
(),
std
::
inserter
(
optimals
,
optimals
.
begin
()),
[
&
](
auto
o
)
{
return
ceil_div
(
o
,
s
);
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
optimals
});
}
else
{
auto
od
=
ceil_div
(
x_shape
.
lens
()[
i
+
2
],
s
);
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
od
,
od
,
0
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
od
,
od
});
}
}
}
else
{
// Does not compute for optimals
auto
min_spatial_dims
=
calc_conv_lens
(
x_shape
.
min_lens
(),
w_shape
.
max_lens
());
auto
max_spatial_dims
=
calc_conv_lens
(
x_shape
.
max_lens
(),
w_shape
.
min_lens
());
auto
opt_spatial_dims
=
calc_conv_lens
(
x_shape
.
opt_lens
(),
w_shape
.
opt_lens
());
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
{}
});
}
}
return
shape
{
x_shape
.
type
(),
output_dyn_dims
};
...
...
src/include/migraphx/op/dequantizelinear.hpp
View file @
fe493c28
...
...
@@ -40,7 +40,11 @@ struct dequantizelinear
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_dims
();
check_shapes
{
inputs
,
*
this
}.
same_dims
().
has
(
2
,
3
);
if
(
inputs
.
size
()
==
3
and
inputs
[
0
].
type
()
!=
inputs
[
2
].
type
())
{
MIGRAPHX_THROW
(
"DEQUANTIZELINEAR: Zero point and input should be the same type."
);
}
return
{
inputs
[
1
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
...
...
src/include/migraphx/op/flatten.hpp
View file @
fe493c28
...
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -59,27 +60,22 @@ struct flatten
auto
s
=
inputs
[
0
];
if
(
s
.
dynamic
())
{
// Doesn't handle optimals
auto
min_lens
=
s
.
min_lens
();
auto
max_lens
=
s
.
max_lens
();
auto
opt_lens
=
s
.
opt_lens
();
// If any of the opt values is 0, output opt will be 0
shape
::
dynamic_dimension
x
=
{
std
::
accumulate
(
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
(),
opt_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{})};
{}};
shape
::
dynamic_dimension
y
=
{
std
::
accumulate
(
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
()
+
axis
,
opt_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
};
{}};
return
{
s
.
type
(),
{
x
,
y
}};
}
else
...
...
src/include/migraphx/op/gathernd.hpp
View file @
fe493c28
...
...
@@ -121,7 +121,7 @@ struct gathernd
// A rank 0 output is a scalar
if
(
output_ndim
==
0
)
return
shape
(
data_shape
.
type
(),
{
shape
::
dynamic_dimension
({
1
,
1
,
0
})});
return
shape
(
data_shape
.
type
(),
{
shape
::
dynamic_dimension
({
1
,
1
})});
// Part of the output shape comes from indices tensor, part from data tensor
std
::
vector
<
shape
::
dynamic_dimension
>
output_dims
(
output_ndim
);
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
fe493c28
...
...
@@ -119,8 +119,8 @@ struct nonmaxsuppression
fixed_shape_error_check
();
}
std
::
vector
<
shape
::
dynamic_dimension
>
out_lens
=
{};
out_lens
.
push_back
({
0
,
max_num_boxes
,
0
});
out_lens
.
push_back
({
3
,
3
,
0
});
out_lens
.
push_back
({
0
,
max_num_boxes
});
out_lens
.
push_back
({
3
,
3
});
return
{
shape
::
int64_type
,
out_lens
};
}
else
...
...
src/include/migraphx/op/pooling.hpp
View file @
fe493c28
...
...
@@ -89,25 +89,17 @@ struct pooling
std
::
vector
<
std
::
size_t
>
output_lens
{};
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
if
(
input_lens
[
i
+
2
]
==
0
)
{
// handle opt = 0
output_lens
.
push_back
(
0
);
}
else
{
std
::
size_t
padding_factor
=
2
*
padding
[
i
];
if
(
padding
.
size
()
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
assert
(
input_lens
[
i
+
2
]
+
padding_factor
>=
lengths
[
i
]);
std
::
size_t
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
std
::
size_t
len
=
(
ceil_mode
)
?
dim_size
/
stride
[
i
]
+
static_cast
<
std
::
size_t
>
((
dim_size
%
stride
[
i
]
!=
0
))
// ceil uint divide
:
dim_size
/
stride
[
i
];
// floor divide
output_lens
.
push_back
(
len
+
1
);
}
std
::
size_t
padding_factor
=
2
*
padding
[
i
];
if
(
padding
.
size
()
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
assert
(
input_lens
[
i
+
2
]
+
padding_factor
>=
lengths
[
i
]);
std
::
size_t
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
std
::
size_t
len
=
(
ceil_mode
)
?
dim_size
/
stride
[
i
]
+
static_cast
<
std
::
size_t
>
((
dim_size
%
stride
[
i
]
!=
0
))
// ceil uint divide
:
dim_size
/
stride
[
i
];
// floor divide
output_lens
.
push_back
(
len
+
1
);
}
return
output_lens
;
}
...
...
@@ -134,19 +126,19 @@ struct pooling
{
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
1
,
1
,
1
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
1
,
1
});
}
return
{
input
.
type
(),
output_dyn_dims
};
}
else
{
// does not compute for optimals
auto
min_spatial_dims
=
calc_spatial_dim_out
(
input
.
min_lens
(),
kdims
);
auto
max_spatial_dims
=
calc_spatial_dim_out
(
input
.
max_lens
(),
kdims
);
auto
opt_spatial_dims
=
calc_spatial_dim_out
(
input
.
opt_lens
(),
kdims
);
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
{}
});
}
return
{
input
.
type
(),
output_dyn_dims
};
}
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment