Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9bff4331
Commit
9bff4331
authored
Mar 21, 2023
by
Paul
Browse files
Merge
parents
214b313f
94a7f6ee
Changes
274
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
440 additions
and
711 deletions
+440
-711
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+2
-1
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+44
-15
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+56
-39
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+20
-2
src/onnx/parse_loop.cpp
src/onnx/parse_loop.cpp
+1
-1
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+58
-34
src/onnx/parse_pad.cpp
src/onnx/parse_pad.cpp
+6
-0
src/onnx/parse_reduce_op.cpp
src/onnx/parse_reduce_op.cpp
+1
-2
src/onnx/parse_reshape.cpp
src/onnx/parse_reshape.cpp
+1
-1
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+7
-2
src/onnx/parse_trilu.cpp
src/onnx/parse_trilu.cpp
+90
-0
src/onnx/parse_where.cpp
src/onnx/parse_where.cpp
+34
-18
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+0
-376
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-193
src/optimize_module.cpp
src/optimize_module.cpp
+49
-0
src/pass_manager.cpp
src/pass_manager.cpp
+11
-10
src/program.cpp
src/program.cpp
+23
-5
src/propagate_constant.cpp
src/propagate_constant.cpp
+16
-7
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+11
-5
src/register_op.cpp
src/register_op.cpp
+10
-0
No files found.
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
9bff4331
...
@@ -113,7 +113,8 @@ struct onnx_parser
...
@@ -113,7 +113,8 @@ struct onnx_parser
void
parse_from
(
std
::
istream
&
is
,
std
::
string
name
=
""
);
void
parse_from
(
std
::
istream
&
is
,
std
::
string
name
=
""
);
void
parse_from
(
const
void
*
data
,
std
::
size_t
size
);
void
parse_from
(
const
void
*
data
,
std
::
size_t
size
);
void
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
);
std
::
vector
<
instruction_ref
>
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
=
false
);
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
...
...
src/onnx/onnx_parser.cpp
View file @
9bff4331
...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
{
{
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
auto
bias_bcast
=
mod
->
add_instruction
(
instruction_ref
bias_bcast
;
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
curr_ins
->
get_shape
().
lens
()}}),
// if curr_ins has a dynamic output shape use 2 input broadcast
args
[
2
]);
if
(
curr_ins
->
get_shape
().
dynamic
())
{
bias_bcast
=
mod
->
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
}}),
args
[
2
],
curr_ins
);
}
else
{
bias_bcast
=
mod
->
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
curr_ins
->
get_shape
().
lens
()}}),
args
[
2
]);
}
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
}
}
return
curr_ins
;
return
curr_ins
;
...
@@ -210,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
...
@@ -210,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if
(
model
.
has_graph
())
if
(
model
.
has_graph
())
{
{
this
->
parse_graph
(
mm
,
model
.
graph
());
(
void
)
this
->
parse_graph
(
mm
,
model
.
graph
());
}
}
}
}
else
else
...
@@ -230,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
...
@@ -230,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if
(
model
.
has_graph
())
if
(
model
.
has_graph
())
{
{
this
->
parse_graph
(
mm
,
model
.
graph
());
(
void
)
this
->
parse_graph
(
mm
,
model
.
graph
());
}
}
}
}
else
else
...
@@ -254,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
...
@@ -254,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return
version
;
return
version
;
}
}
void
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
std
::
vector
<
instruction_ref
>
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
)
{
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
for
(
auto
&&
f
:
graph
.
initializer
())
for
(
auto
&&
f
:
graph
.
initializer
())
...
@@ -362,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -362,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std
::
back_inserter
(
output_ins
),
std
::
back_inserter
(
output_ins
),
[
&
](
const
auto
&
name
)
{
return
instructions
[
name
];
});
[
&
](
const
auto
&
name
)
{
return
instructions
[
name
];
});
// add the return instuction
if
(
not
inlining
)
mod
->
add_return
(
output_ins
);
{
// add the return instuction
mod
->
add_return
(
output_ins
);
// remove instructions added in this mod
// Remove instructions added in module (this is turned off for subgraph inlining)
erase_if
(
instructions
,
[
&
](
auto
&&
p
)
{
return
mod
->
has_instruction
(
p
.
second
);
});
erase_if
(
instructions
,
[
&
](
auto
&&
p
)
{
return
mod
->
has_instruction
(
p
.
second
);
});
}
return
output_ins
;
}
}
literal
onnx_parser
::
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
literal
onnx_parser
::
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
...
@@ -393,18 +409,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
@@ -393,18 +409,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
literal
onnx_parser
::
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
literal
onnx_parser
::
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
{
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
if
(
not
t
.
external_data
().
empty
())
auto
type
=
get_type
(
t
.
data_type
());
shape
tensor_shape
(
type
,
dims
);
auto
external_data
=
t
.
external_data
();
if
(
not
external_data
.
empty
())
{
{
const
std
::
string
&
data_file
=
t
.
external_data
().
at
(
0
).
value
();
const
std
::
string
&
data_file
=
external_data
.
at
(
0
).
value
();
auto
raw_buffer
=
read_buffer
(
path
+
"/"
+
data_file
);
size_t
num_data_fields
=
external_data
.
size
();
size_t
offset
=
0
;
size_t
nbytes
=
tensor_shape
.
bytes
();
if
(
num_data_fields
>
1
)
// if offset field is present
{
offset
=
std
::
stoul
(
t
.
external_data
().
at
(
1
).
value
());
}
if
(
num_data_fields
>
2
)
// if nbytes field is present
{
nbytes
=
std
::
stoul
(
t
.
external_data
().
at
(
2
).
value
());
}
auto
raw_buffer
=
read_buffer
(
path
+
"/"
+
data_file
,
offset
,
nbytes
);
std
::
string
s
(
raw_buffer
.
begin
(),
raw_buffer
.
end
());
std
::
string
s
(
raw_buffer
.
begin
(),
raw_buffer
.
end
());
auto
type
=
get_type
(
t
.
data_type
());
return
create_literal
(
type
,
dims
,
s
.
data
());
return
create_literal
(
type
,
dims
,
s
.
data
());
}
}
if
(
t
.
has_raw_data
())
if
(
t
.
has_raw_data
())
{
{
const
std
::
string
&
s
=
t
.
raw_data
();
const
std
::
string
&
s
=
t
.
raw_data
();
auto
type
=
get_type
(
t
.
data_type
());
return
create_literal
(
type
,
dims
,
s
.
data
());
return
create_literal
(
type
,
dims
,
s
.
data
());
}
}
...
...
src/onnx/parse_gemm.cpp
View file @
9bff4331
...
@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
float
alpha
=
1.0
f
;
auto
a_arg
=
args
[
0
];
float
beta
=
1.0
f
;
auto
b_arg
=
args
[
1
];
bool
transa
=
false
;
if
(
a_arg
->
get_shape
().
ndim
()
!=
2
or
b_arg
->
get_shape
().
ndim
()
!=
2
)
bool
transb
=
false
;
{
MIGRAPHX_THROW
(
"PARSE_GEMM: A and B should be rank 2, A is rank "
+
std
::
to_string
(
a_arg
->
get_shape
().
ndim
())
+
", B is rank "
+
std
::
to_string
(
b_arg
->
get_shape
().
ndim
()));
}
float
alpha
=
1.0
f
;
float
beta
=
1.0
f
;
bool
trans_a
=
false
;
bool
trans_b
=
false
;
if
(
contains
(
info
.
attributes
,
"alpha"
))
if
(
contains
(
info
.
attributes
,
"alpha"
))
{
{
alpha
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
alpha
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
...
@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
}
}
if
(
contains
(
info
.
attributes
,
"transA"
))
if
(
contains
(
info
.
attributes
,
"transA"
))
{
{
transa
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
trans
_
a
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
}
}
if
(
contains
(
info
.
attributes
,
"transB"
))
if
(
contains
(
info
.
attributes
,
"transB"
))
{
{
transb
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
trans
_
b
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
}
std
::
vector
<
int64_t
>
perm
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
int64_t
{
0
});
auto
dot_type
=
a_arg
->
get_shape
().
type
();
// swap the last two elements
std
::
swap
(
*
perm
.
rbegin
(),
*
(
perm
.
rbegin
()
+
1
));
auto
l1
=
args
[
0
];
auto
dot_type
=
l1
->
get_shape
().
type
();
if
(
alpha
!=
1.0
f
)
if
(
alpha
!=
1.0
f
)
{
{
auto
alpha_literal
=
info
.
add_literal
(
alpha
);
auto
alpha_literal
=
info
.
add_literal
(
alpha
);
l1
=
info
.
add_broadcastable_binary_op
(
"mul"
,
alpha_literal
,
l1
);
a_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
alpha_literal
,
a_arg
);
if
(
l1
->
get_shape
().
type
()
!=
dot_type
)
if
(
a_arg
->
get_shape
().
type
()
!=
dot_type
)
{
{
l1
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
l1
);
a_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
a_arg
);
}
}
}
}
l1
=
a_arg
=
(
trans_a
)
(
transa
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
l1
)
:
l1
;
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
a_arg
)
auto
l2
=
(
transb
)
:
a_arg
;
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
b_arg
=
(
trans_b
)
:
args
[
1
];
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
l1
,
l2
);
auto
dot_ins
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
if
(
not
float_equal
(
beta
,
0.0
f
)
&&
args
[
2
]
->
get_shape
().
elements
()
>
0
)
if
(
not
float_equal
(
beta
,
0.0
f
))
{
{
auto
out_lens
=
l1
->
get_shape
().
lens
();
auto
c_arg
=
args
[
2
];
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
if
(
dot_ins
->
get_shape
().
dynamic
())
auto
l3
=
args
[
2
];
auto
l3_lens
=
l3
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
()))
{
{
l3
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
),
args
[
2
],
dot_ins
);
args
[
2
]);
}
}
auto
beta_literal
=
info
.
add_literal
(
beta
);
else
auto
beta_l3
=
info
.
add_broadcastable_binary_op
(
"mul"
,
l3
,
beta_literal
);
if
(
beta_l3
->
get_shape
().
type
()
!=
dot_type
)
{
{
beta_l3
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
beta_l3
);
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_l3
);
if
(
not
float_equal
(
beta
,
1.0
f
))
{
auto
beta_literal
=
info
.
add_literal
(
beta
);
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
c_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
}
}
return
dot_ins
;
return
ret
;
}
}
};
};
...
...
src/onnx/parse_if.cpp
View file @
9bff4331
...
@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
...
@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!"
);
" condition input can have only one element!"
);
}
}
// Fold instruction if condition is constant thus can be evaled
// prior to inference
if
(
args
.
front
()
->
can_eval
())
{
auto
cond_arg
=
args
.
front
()
->
eval
();
auto
*
mod
=
info
.
mod
;
// then branch
if
(
cond_arg
.
at
<
bool
>
())
{
return
parser
.
parse_graph
(
mod
,
then_graph
,
true
);
}
// else branch
else
{
return
parser
.
parse_graph
(
mod
,
else_graph
,
true
);
}
}
std
::
string
then_name
=
info
.
name
+
"_if"
;
std
::
string
then_name
=
info
.
name
+
"_if"
;
module_ref
then_mdl
=
parser
.
prog
.
create_module
(
then_name
);
module_ref
then_mdl
=
parser
.
prog
.
create_module
(
then_name
);
...
@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
...
@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref
else_mdl
=
parser
.
prog
.
create_module
(
else_name
);
module_ref
else_mdl
=
parser
.
prog
.
create_module
(
else_name
);
// parse the then sub_graph
// parse the then sub_graph
parser
.
parse_graph
(
then_mdl
,
then_graph
);
(
void
)
parser
.
parse_graph
(
then_mdl
,
then_graph
);
// parse_the else sub_graph
// parse_the else sub_graph
parser
.
parse_graph
(
else_mdl
,
else_graph
);
(
void
)
parser
.
parse_graph
(
else_mdl
,
else_graph
);
auto
then_out_shapes
=
then_mdl
->
get_output_shapes
();
auto
then_out_shapes
=
then_mdl
->
get_output_shapes
();
auto
else_out_shapes
=
else_mdl
->
get_output_shapes
();
auto
else_out_shapes
=
else_mdl
->
get_output_shapes
();
...
...
src/onnx/parse_loop.cpp
View file @
9bff4331
...
@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
...
@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref
sub_mod
=
parser
.
prog
.
create_module
(
mod_name
);
module_ref
sub_mod
=
parser
.
prog
.
create_module
(
mod_name
);
// parse the sub_graph
// parse the sub_graph
parser
.
parse_graph
(
sub_mod
,
sub_graph
);
(
void
)
parser
.
parse_graph
(
sub_mod
,
sub_graph
);
auto
ret
=
info
.
add_instruction
(
auto
ret
=
info
.
add_instruction
(
make_op
(
"loop"
,
{{
"max_iterations"
,
max_iterations
}}),
args
,
{
sub_mod
});
make_op
(
"loop"
,
{{
"max_iterations"
,
max_iterations
}}),
args
,
{
sub_mod
});
...
...
src/onnx/parse_matmul.cpp
View file @
9bff4331
...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
const
onnx_parser
::
node_info
&
info
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
auto
l0
=
args
[
0
];
auto
a0
=
args
[
0
];
auto
l1
=
args
[
1
];
auto
a1
=
args
[
1
];
auto
l0_len
s
=
l
0
->
get_shape
()
.
lens
()
;
auto
s
0
=
a
0
->
get_shape
();
auto
l1_len
s
=
l
1
->
get_shape
()
.
lens
()
;
auto
s
1
=
a
1
->
get_shape
();
// args[0] is a vector, prepend 1 to the shape
instruction_ref
dot_res
;
bool
is_a_prepended
=
false
;
bool
is_a_prepended
=
false
;
if
(
l0_lens
.
size
()
==
1
)
bool
is_b_appended
=
false
;
if
(
s0
.
ndim
()
==
1
)
{
{
is_a_prepended
=
true
;
is_a_prepended
=
true
;
l0_lens
.
insert
(
l0_lens
.
begin
(),
1
);
a0
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
args
[
0
]);
l0
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
args
[
0
]);
}
}
if
(
s1
.
ndim
()
==
1
)
bool
is_b_appended
=
false
;
if
(
l1_lens
.
size
()
==
1
)
{
{
is_b_appended
=
true
;
is_b_appended
=
true
;
l1_lens
.
push_back
(
1
);
a1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
l1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
}
}
instruction_ref
bl0
=
l0
;
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
instruction_ref
bl1
=
l1
;
if
(
not
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
if
(
opd
.
op_name
==
"quant_dot"
)
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
{
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic MatMulInteger not supported"
)
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
}
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
auto
s0_dds
=
a0
->
get_shape
().
to_dynamic
().
dyn_dims
(
);
l0_broadcasted_lens
=
output_lens
;
auto
s1_dds
=
a1
->
get_shape
().
to_dynamic
().
dyn_dims
()
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
// TODO: handling this case requires a new multibroadcast mode
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
l1_lens
.
end
());
if
(
not
std
::
equal
(
if
(
l0_lens
!=
l0_broadcasted_lens
)
s0_dds
.
rbegin
()
+
2
,
s0_dds
.
rend
(),
s1_dds
.
rbegin
()
+
2
,
s1_dds
.
rend
())
)
{
{
bl0
=
info
.
add_instruction
(
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic shape broadcasting not supported"
);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l0_broadcasted_lens
}}),
l0
);
}
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
a0
,
a1
);
}
else
{
auto
s0_lens
=
a0
->
get_shape
().
lens
();
auto
s1_lens
=
a1
->
get_shape
().
lens
();
instruction_ref
ba0
=
a0
;
instruction_ref
ba1
=
a1
;
// try broadcasting if dimensions other than last two do not match
if
(
not
std
::
equal
(
s0_lens
.
rbegin
()
+
2
,
s0_lens
.
rend
(),
s1_lens
.
rbegin
()
+
2
,
s1_lens
.
rend
()))
{
{
bl1
=
info
.
add_instruction
(
auto
l0_it
=
s0_lens
.
begin
()
+
s0_lens
.
size
()
-
2
;
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
l1
);
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
s0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
s1_lens
.
begin
()
+
s1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
s1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
s0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
s1_lens
.
end
());
if
(
s0_lens
!=
l0_broadcasted_lens
)
{
ba0
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l0_broadcasted_lens
}}),
a0
);
}
if
(
s1_lens
!=
l1_broadcasted_lens
)
{
ba1
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
a1
);
}
}
}
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
ba0
,
ba1
);
}
}
instruction_ref
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
// squeeze the appended or prepended dimensions
int64_t
num_axis
=
dot_res
->
get_shape
().
ndim
();
if
(
is_a_prepended
)
if
(
is_a_prepended
)
{
{
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
...
...
src/onnx/parse_pad.cpp
View file @
9bff4331
...
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
...
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
{
{
auto
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
auto
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
if
(
mode
==
"reflect"
)
if
(
mode
==
"reflect"
)
{
if
(
args
.
front
()
->
get_shape
().
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_PAD: reflect padding with dynamic shape not supported"
);
}
return
reflect_pad
(
info
,
pads
,
args
.
front
());
return
reflect_pad
(
info
,
pads
,
args
.
front
());
}
if
(
mode
!=
"constant"
)
if
(
mode
!=
"constant"
)
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
...
...
src/onnx/parse_reduce_op.cpp
View file @
9bff4331
...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
}
}
else
else
{
{
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
axes
.
resize
(
args
.
front
()
->
get_shape
().
ndim
());
axes
.
resize
(
n_dim
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
}
}
}
}
...
...
src/onnx/parse_reshape.cpp
View file @
9bff4331
...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
if
(
args
.
size
()
==
2
)
if
(
args
.
size
()
==
2
)
{
{
auto
s
=
args
[
1
]
->
eval
();
auto
s
=
args
[
1
]
->
eval
();
check_arg_empty
(
s
,
"Reshape:
dynamic shape
is not supported"
);
check_arg_empty
(
s
,
"Reshape:
non-constant shape input
is not supported"
);
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
}
}
...
...
src/onnx/parse_slice.cpp
View file @
9bff4331
...
@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std
::
vector
<
int64_t
>
steps
;
std
::
vector
<
int64_t
>
steps
;
// slice can have up to 5 inputs, we first check the 5th one
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice
// to decide whether MIGRAPHX can handle this slice
.
if
(
args
.
size
()
==
5
)
if
(
args
.
size
()
==
5
)
{
{
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
...
@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
}
}
// If axes arg is not given, the default is all of them.
if
(
op
.
axes
.
empty
())
if
(
op
.
axes
.
empty
())
{
{
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
op
.
axes
=
axes
;
op
.
axes
=
axes
;
}
}
...
@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
steps
.
size
()))
for
(
auto
i
:
range
(
steps
.
size
()))
{
{
if
(
steps
[
i
]
>=
0
)
if
(
steps
[
i
]
>=
0
)
...
@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
if
(
not
raxes
.
empty
())
if
(
not
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
{
std
::
vector
<
int64_t
>
nsteps
;
std
::
vector
<
int64_t
>
nsteps
;
...
...
src/onnx/parse_trilu.cpp
0 → 100644
View file @
9bff4331
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
/// ONNX parser for the `Trilu` operator.
///
/// Builds a constant 0/1 mask over the last two dimensions of the input and
/// multiplies the input by it, keeping either the upper or the lower triangle.
struct parse_trilu : op_parser<parse_trilu>
{
    /// ONNX operator names handled by this parser.
    std::vector<op_desc> operators() const { return {{"Trilu"}}; }

    /// Emit `mask * args[0]`, where `mask` is a num_rows x num_cols literal.
    ///
    /// @param info  node information; the optional integer attribute "upper"
    ///              selects the upper (non-zero, default) or lower (zero) triangle.
    /// @param args  args[0] is the input (asserted to have at least 2 dimensions);
    ///              optional args[1] is the diagonal offset k, which must be
    ///              evaluable at parse time.
    /// @return the instruction produced by the broadcastable "mul".
    /// @throws via MIGRAPHX_THROW when k is negative; `check_arg_empty` reports
    ///         a k that cannot be evaluated at parse time.
    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto input_shape = args[0]->get_shape();
        assert(input_shape.ndim() >= 2);
        auto input_lens = input_shape.lens();
        // The triangle is taken over the last two dimensions.
        size_t num_rows = *(input_lens.rbegin() + 1);
        size_t num_cols = input_lens.back();
        int k           = 0;
        bool upper      = true;

        if(args.size() > 1)
        {
            auto arg_k = args[1]->eval();
            check_arg_empty(arg_k, "PARSE_TRILU: dynamic k not supported");
            k = arg_k.at<int>();
        }

        if(k < 0)
            MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");

        if(contains(info.attributes, "upper"))
        {
            upper = static_cast<bool>(info.attributes.at("upper").i());
        }

        shape::type_t output_type = args[0]->get_shape().type();

        // Every entry starts at `upper`; the leading columns of each row are then
        // flipped to `not upper`:
        //  - upper: columns j < i + k are zeroed, so the k-th diagonal and
        //    everything right of it is kept;
        //  - lower: columns j <= i + k are set to one, so the k-th diagonal itself
        //    is retained, as the ONNX Trilu definition requires.
        // k is non-negative here, so the conversion to an unsigned type is safe.
        const std::size_t first_row_flips = static_cast<std::size_t>(k) + (upper ? 0 : 1);

        std::vector<bool> mask_mat(num_rows * num_cols, upper);
        for(std::size_t i = 0; i < num_rows; i++)
        {
            // One more leading column is flipped on every following row.
            const std::size_t flips = std::min<std::size_t>(first_row_flips + i, num_cols);
            for(std::size_t j = 0; j < flips; j++)
            {
                mask_mat[i * num_cols + j] = not upper;
            }
        }

        auto mask = info.add_literal(
            migraphx::literal{migraphx::shape{output_type, {num_rows, num_cols}}, mask_mat});

        return info.add_broadcastable_binary_op("mul", mask, args[0]);
    }
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_where.cpp
View file @
9bff4331
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
...
@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const
onnx_parser
::
node_info
&
info
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
auto
lens
=
// TODO: broadcasting for dynamic shapes is only implemented
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
// for binary ops at time of writing, not ternary ops.
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
// When it becomes available, add multibroadcasting steps in the dynamic shape case.
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
// For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
}))
{
{
args
[
0
]
=
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
}
}
else
if
(
std
::
none_of
(
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
})
)
{
{
args
[
1
]
=
// If shapes are static and any are broadcasted, insert multibroadcast ops
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
auto
lens
=
}
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
{
{
args
[
2
]
=
args
[
1
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
}
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
2
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
}
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
}
else
MIGRAPHX_THROW
(
"PARSE_WHERE: doesn't support mixed static and dynamic shape inputs"
);
}
}
};
};
...
...
src/opt/memory_coloring_impl.cpp
deleted
100644 → 0
View file @
214b313f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// Drive the whole memory-coloring pass for the module held in `p_mod`:
/// compute implicit dependencies, build live intervals, assign offsets to
/// every queued interval, rewrite allocations, and optionally verify.
void memory_coloring_impl::run()
{
    // Calculate implicit dependencies before liveness is built; build() reads them.
    mod_implicit_deps = p_mod->calc_implicit_deps();
    MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
    MIGRAPHX_DEBUG(dump_module());
    build();

    // Nothing was found live, so there is nothing to color or rewrite.
    if(num_of_lives == 0)
        return;

    MIGRAPHX_DEBUG(dump_intervals());

    // Coloring: place intervals in priority order until the queue is drained.
    for(; not alloc_queue.empty(); alloc_queue.pop())
        allocate(alloc_queue.top());

    // The rewrite runs only after every interval has been processed.
    rewrite();
    if(enable_verify)
        verify();
}
/// Choose a byte offset for `interval` and store it in `interval->segment.offset`.
///
/// The already-placed ranges that conflict with this interval are walked in
/// ascending offset order; the first gap large enough for the interval is
/// used, otherwise the interval is placed after the last conflicting range.
///
/// @param interval live interval to place; its `result` shape supplies the size.
/// @return false (nothing assigned) when the shape occupies zero bytes,
///         true otherwise. On success `required_bytes` is grown to cover the
///         new placement.
bool memory_coloring_impl::allocate(interval_ptr interval)
{
    const shape& s         = interval->result;
    const std::size_t size = s.bytes();
    if(size == 0)
        return false;
    // Fall back to 4 bytes per element when the shape reports no elements.
    const std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
    live_range& segment            = interval->segment;
    const int vn                   = segment.vn;

    // Conflicting ranges that already have an offset, lowest offset on top.
    std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
    // For every occupied offset, the largest conflicting range starting there.
    std::unordered_map<std::size_t, live_range*> offset2_live;

    const auto conflicts = conflict_table.find(vn);
    if(conflicts != conflict_table.end())
    {
        for(const auto& conflict_vn : conflicts->second)
        {
            live_range* range            = live_ranges[conflict_vn];
            const std::size_t placed_at = range->offset;
            if(placed_at == invalid_offset)
                continue;
            conflict_queue.push(range);
            const auto slot = offset2_live.emplace(placed_at, range);
            if(not slot.second)
            {
                live_range*& largest = slot.first->second;
                assert(largest->offset == placed_at);
                if(largest->size < range->size)
                    largest = range;
            }
        }
    }

    // Round `value` up to the next multiple of `alignment`.
    const auto align_up = [](std::size_t value, std::size_t alignment) {
        const std::size_t rem = value % alignment;
        return rem == 0 ? value : value + (alignment - rem);
    };

    std::size_t offset = 0;
    while(not conflict_queue.empty())
    {
        live_range* range             = conflict_queue.top();
        const std::size_t iter_offset = range->offset;
        if(offset > iter_offset)
        {
            // Candidate already lies past the start of this range; make sure
            // it also lies past its end.
            offset = std::max(offset, iter_offset + range->size);
        }
        else if(offset2_live.at(iter_offset) == range)
        {
            // `offset` is fully aligned here, so a gap accepted by this test
            // is still large enough at the offset that is finally stored.
            if((iter_offset > offset) and (iter_offset - offset) >= size)
            {
                break;
            }
            offset = iter_offset + range->size;
        }
        // Alignment: first to the element size, then to 4 bytes. When the int8
        // type is used the offset could be any number, and if it is not 4-byte
        // aligned, miopen int8 convolution can crash. Aligning inside the loop
        // (rather than once after it) keeps the gap test above consistent with
        // the final placement, so rounding can no longer push the interval into
        // the next conflicting range.
        offset = align_up(align_up(offset, element_size), 4);
        conflict_queue.pop();
    }

    segment.offset = offset;
    MIGRAPHX_DEBUG(segment.dump());
    required_bytes = std::max(required_bytes, offset + segment.size);
    return true;
}
void
memory_coloring_impl
::
build
()
{
std
::
size_t
num_of_instrs
=
p_mod
->
size
();
if
(
num_of_instrs
==
0
)
return
;
auto
cur_points
=
num_of_instrs
*
2
;
instruction_ref
iter
=
p_mod
->
end
();
instruction_ref
begin
=
p_mod
->
begin
();
std
::
vector
<
instruction_ref
>
dead_instrs
;
std
::
set
<
int
>
live_set
;
// Build live intervals.
live_intervals
.
resize
(
num_of_instrs
);
do
{
iter
=
std
::
prev
(
iter
);
const
instruction
*
p_iter
=
&
(
*
iter
);
interval_ptr
def_interval
=
nullptr
;
bool
is_dead
=
false
;
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
def_interval
=
instr2_live
[
p_iter
];
bool
is_lit
=
is_literal
(
iter
);
if
(
is_allocate
(
iter
)
or
is_lit
)
{
live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
get_shape
();
def_interval
->
is_literal
=
is_lit
;
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
if
(
not
is_lit
or
unify_literals
)
alloc_queue
.
push
(
def_interval
);
live_set
.
erase
(
range
.
vn
);
}
}
else
if
(
not
is_param
(
iter
)
&&
not
is_outline
(
iter
)
&&
not
is_check_context
(
iter
))
{
is_dead
=
true
;
}
auto
inputs
=
iter
->
inputs
();
if
(
contains
(
mod_implicit_deps
,
iter
))
{
const
auto
&
impl_deps
=
mod_implicit_deps
.
at
(
iter
);
inputs
.
insert
(
inputs
.
end
(),
impl_deps
.
begin
(),
impl_deps
.
end
());
}
for
(
auto
&&
arg
:
inputs
)
{
if
(
not
p_mod
->
has_instruction
(
arg
))
continue
;
if
(
is_param
(
arg
)
or
is_outline
(
arg
))
{
if
(
is_output_param
(
arg
))
is_dead
=
false
;
if
(
def_interval
!=
nullptr
)
{
def_interval
->
is_live_on_entry
=
true
;
}
continue
;
}
const
instruction
*
p_arg
=
&
(
*
instruction
::
get_output_alias
(
arg
));
if
(
instr2_live
.
find
(
p_arg
)
==
instr2_live
.
end
())
{
// First time see a use, create a live interval.
int
id
=
num_of_lives
++
;
interval_ptr
interval
=
&
(
live_intervals
[
id
]);
interval
->
id
=
id
;
interval
->
segment
.
end
=
cur_points
;
interval
->
segment
.
vn
=
++
max_value_number
;
interval
->
add_use
(
cur_points
);
instr2_live
[
p_arg
]
=
interval
;
add_conflicts
(
live_set
,
max_value_number
);
live_set
.
insert
(
max_value_number
);
live_ranges
[
max_value_number
]
=
&
(
interval
->
segment
);
earliest_end_point
=
cur_points
;
if
(
latest_end_point
==
-
1
)
latest_end_point
=
cur_points
;
}
else
{
interval_ptr
interval
=
instr2_live
[
p_arg
];
interval
->
add_use
(
cur_points
);
assert
(
live_set
.
find
(
interval
->
id
)
!=
live_set
.
end
());
}
}
if
(
is_dead
)
dead_instrs
.
push_back
(
iter
);
cur_points
-=
2
;
}
while
(
iter
!=
begin
);
}
/// Replace every colored allocation instruction with a "load" from a single
/// "scratch" parameter at the offset chosen for its interval.
///
/// The scratch parameter is a 1-D float buffer whose element count is
/// `required_bytes` rounded up to whole floats.
void memory_coloring_impl::rewrite()
{
    const std::size_t float_count = (required_bytes + sizeof(float) - 1) / sizeof(float);
    std::vector<std::size_t> dims;
    dims.push_back(float_count);
    shape s                       = {shape::float_type, dims};
    instruction_ref scratch_param = p_mod->add_parameter("scratch", s);

    for(auto ins : iterator_for(*p_mod))
    {
        const instruction* p_iter = &(*ins);
        const auto live           = instr2_live.find(p_iter);
        if(live == instr2_live.end())
            continue;

        interval_ptr interval = live->second;
        // Skip intervals whose begin point was never set.
        if(interval->get_begin() == invalid_offset)
            continue;
        // Literals are left alone unless they take part in coloring.
        if(not unify_literals && interval->is_literal)
            continue;

        std::size_t offset = 0;
        if(interval->get_offset() != invalid_offset)
        {
            offset = interval->get_offset();
        }
        else
        {
            // allocate() only declines to assign an offset for zero-byte shapes.
            assert(interval->result.bytes() == 0);
        }

        if(not is_allocate(ins))
            continue;
        p_mod->replace_instruction(
            ins,
            make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
            scratch_param);
    }
    MIGRAPHX_DEBUG(dump("---After rewrite---"));
    MIGRAPHX_DEBUG(dump_module());
}
void
memory_coloring_impl
::
verify
()
{
if
(
num_of_lives
>
0
)
{
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
const
live_interval
&
interval
=
live_intervals
[
i
];
const
live_range
&
segment
=
interval
.
segment
;
if
(
segment
.
begin
==
invalid_offset
)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue
;
}
if
(
segment
.
offset
==
invalid_offset
)
{
continue
;
}
int
vn
=
segment
.
vn
;
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
if
(
range
->
offset
==
invalid_offset
)
continue
;
if
(
not
is_disjoin
(
*
range
,
segment
))
MIGRAPHX_THROW
(
"range and segment is not disjoined"
);
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
/// Debug helper: print `str` on its own line to standard output.
void memory_coloring_impl::dump(const std::string& str)
{
    std::cout << str;
    std::cout << std::endl;
}
/// Debug helper: stream the module being colored to standard output.
void memory_coloring_impl::dump_module()
{
    std::ostream& out = std::cout;
    out << *p_mod;
    out << std::endl;
}
/// Debug helper: print every live interval followed by the conflict table.
///
/// The conflict table is printed on a single line, one " segment:N =>" entry
/// per value number from 0 to `max_value_number`, each followed by the
/// comma-terminated value numbers it conflicts with.
void memory_coloring_impl::dump_intervals()
{
    if(num_of_lives <= 0)
        return;

    std::cout << "---live intervals ---" << std::endl;
    for(int i = 0; i < num_of_lives; ++i)
    {
        live_intervals[i].dump();
    }

    std::cout << "---conflict table---" << std::endl;
    for(int i = 0; i <= max_value_number; ++i)
    {
        std::cout << " segment:" << i;
        std::cout << " =>";
        // Look the entry up with find(): operator[] would insert an empty set
        // for every value number without conflicts, so that merely dumping the
        // table would modify it.
        const auto entry = conflict_table.find(i);
        if(entry == conflict_table.end())
            continue;
        for(const auto& conflicting_vn : entry->second)
        {
            std::cout << (conflicting_vn) << ",";
        }
    }
    std::cout << std::endl;
}
// Map a liveness tracking point to an instruction enum: positive points map
// to (x / 2) - 1, anything else maps to invalid_offset converted to int.
static int get_ins_enum(int x)
{
    const bool is_tracked_point = x > 0;
    return is_tracked_point ? (x / 2) - 1 : static_cast<int>(invalid_offset);
}
/// Debug helper: print this range's value number, its [begin, end] instruction
/// span and, when an offset has been assigned, its inclusive byte span.
void live_range::dump()
{
    std::ostream& out = std::cout;
    out << " segment:" << vn;
    out << " [" << get_ins_enum(begin) << ", " << get_ins_enum(end) << "]";
    const bool has_memory = offset != invalid_offset;
    if(has_memory)
    {
        out << " mem:";
        out << " [" << offset << "," << offset + size - 1 << "]";
    }
    out << std::endl;
}
void
live_interval
::
dump
()
{
std
::
cout
<<
"id:"
<<
id
;
segment
.
dump
();
std
::
cout
<<
" uses:"
;
for
(
const
auto
&
iter
:
use_points
)
{
std
::
cout
<<
" "
<<
get_ins_enum
(
iter
)
<<
","
;
}
std
::
cout
<<
" def:"
;
std
::
cout
<<
" "
<<
get_ins_enum
(
def_point
);
if
(
is_literal
)
std
::
cout
<<
" literal"
;
std
::
cout
<<
" "
<<
result
;
std
::
cout
<<
std
::
endl
;
}
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/opt/memory_coloring_impl.hpp
deleted
100644 → 0
View file @
214b313f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// Sentinel meaning "not set": the largest value a std::size_t can hold.
static constexpr std::size_t invalid_offset = std::numeric_limits<std::size_t>::max();
/// Plain aggregate describing one live range. Field order matters:
/// the range is brace-initialized positionally elsewhere.
struct live_range
{
    /// Begin point in the instruction stream.
    std::size_t begin;
    /// End point in the instruction stream.
    std::size_t end;
    /// Offset to the base pointer of the allocated memory trunk.
    std::size_t offset;
    /// Value number that identifies this live_range.
    std::size_t vn;
    /// Size of the required memory in bytes.
    std::size_t size;
#ifdef MIGRAPHX_DEBUG_OPT
    void dump();
#endif
};
/// A live range together with its bookkeeping: id, use points, definition
/// point, result shape and flags. A default-constructed interval has every
/// position in its segment set to invalid_offset and a size of 0.
struct live_interval
{
    live_interval()
        : segment{invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0}
    {
    }

    /// Record a use point; the newest use is kept at the front of the list.
    void add_use(std::size_t use) { use_points.emplace_front(use); }

    /// Begin point of the underlying segment.
    std::size_t get_begin() const { return segment.begin; }

    /// End point of the underlying segment.
    std::size_t get_end() const { return segment.end; }

    /// Offset of the underlying segment.
    /// NOTE(review): returned as long long although segment.offset is a
    /// std::size_t; callers compare the result against invalid_offset.
    long long get_offset() const { return segment.offset; }

#ifdef MIGRAPHX_DEBUG_OPT
    void dump();
#endif

    live_range segment;
    std::size_t id = invalid_offset;
    std::list<std::size_t> use_points{};
    std::size_t def_point = invalid_offset;
    shape result{};
    bool is_literal       = false;
    bool is_live_on_entry = false;
};

/// Non-owning handle to a live_interval.
using interval_ptr = live_interval*;
/// State and algorithms for one memory-coloring run over a single module.
struct memory_coloring_impl
{
    /// @param p        module to operate on (not owned).
    /// @param alloc_op name of the instruction that is treated as an allocation.
    /// @param p_verify whether run() calls verify() after rewriting.
    memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
        : p_mod(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
    {
    }

    bool allocate(interval_ptr);

    /// Mark value number `val` as conflicting (symmetrically) with every value
    /// number currently in `live_set`.
    void add_conflicts(const std::set<int>& live_set, int val)
    {
        for(const auto& live_vn : live_set)
        {
            conflict_table[live_vn].insert(val);
            conflict_table[val].insert(live_vn);
        }
    }

    void build();
    void run();
    void rewrite();

    private:
    static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; }

    /// True for a parameter whose name contains "#output_".
    static bool is_output_param(const instruction_ref ins)
    {
        if(not is_param(ins))
            return false;
        auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
        return contains(param_name, "#output_");
    }

    bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }

    static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }

    static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }

    static bool is_check_context(const instruction_ref ins)
    {
        return ins->name() == "check_context";
    }

    /// True when the byte spans [offset, offset + size - 1] of the two ranges
    /// do not overlap. A range of size 0 is never reported as disjoint.
    static bool is_disjoin(const live_range& range1, const live_range& range2)
    {
        const bool both_non_empty = (range1.size != 0) and (range2.size != 0);
        if(not both_non_empty)
            return false;
        const auto last1 = range1.offset + range1.size - 1;
        const auto last2 = range2.offset + range2.size - 1;
        return (last1 < range2.offset) or (last2 < range1.offset);
    }

    void verify();
#ifdef MIGRAPHX_DEBUG_OPT
    void dump(const std::string&);
    void dump_module();
    void dump_intervals();
#endif

    /// Comparators for the two priority queues used by the pass.
    struct ordering
    {
        /// Intervals: the longer live span has priority; ties are broken by
        /// the larger byte size, then by the smaller id.
        bool operator()(const interval_ptr& i1, const interval_ptr& i2) const
        {
            const auto len1 = i1->get_end() - i1->get_begin();
            const auto len2 = i2->get_end() - i2->get_begin();
            if(len1 != len2)
                return len1 < len2;

            const auto bytes1 = i1->result.bytes();
            const auto bytes2 = i2->result.bytes();
            if(bytes1 != bytes2)
                return bytes1 < bytes2;

            return i1->id > i2->id;
        }

        /// Ranges: the smallest offset has priority.
        bool operator()(const live_range* i1, const live_range* i2) const
        {
            return (i1->offset > i2->offset);
        }
    };

    module* p_mod;
    // Map an instruction to the live interval of the value it produces.
    std::unordered_map<const instruction*, interval_ptr> instr2_live;
    // Universe of live intervals.
    std::vector<live_interval> live_intervals = {};
    // Map live range value number to live range.
    std::unordered_map<int, live_range*> live_ranges = {};
    // Map live range value number to a set of conflicting live ranges' value numbers.
    std::unordered_map<int, std::set<int>> conflict_table = {};
    // Priority queue for coloring.
    std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue{};

    int num_of_lives           = 0;
    int max_value_number       = -1;
    std::size_t required_bytes = 0;
    // The earliest program point where a live interval ends.
    int earliest_end_point = -1;
    // The latest program point where a live interval ends.
    int latest_end_point = -1;
    // Whether to unify literals into coloring.
    bool unify_literals = false;
    std::string allocation_op{};
    bool enable_verify;
    ins_dep_map mod_implicit_deps;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/optimize_module.cpp
0 → 100644
View file @
9bff4331
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// Run the module clean-up pipeline twice through the given pass manager.
void optimize_module::apply(module_pass_manager& mpm) const
{
    // One sweep: simplify, remove common subexpressions, then fold constants,
    // with dead-code elimination after each of the last two stages.
    const auto sweep = [&] {
        mpm.run_pass(simplify_reshapes{});
        mpm.run_pass(simplify_algebra{});
        mpm.run_pass(eliminate_common_subexpression{});
        mpm.run_pass(dead_code_elimination{});
        mpm.run_pass(propagate_constant{});
        mpm.run_pass(dead_code_elimination{});
    };

    sweep();
    sweep();
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/pass_manager.cpp
View file @
9bff4331
...
@@ -39,6 +39,7 @@ namespace migraphx {
...
@@ -39,6 +39,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PASSES
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PASSES
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TIME_PASSES
);
void
validate_pass
(
module
&
mod
,
const
pass
&
p
,
tracer
trace
)
void
validate_pass
(
module
&
mod
,
const
pass
&
p
,
tracer
trace
)
{
{
...
@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
...
@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
virtual
void
run_pass
(
const
pass
&
p
)
override
virtual
void
run_pass
(
const
pass
&
p
)
override
{
{
assert
(
mod
);
assert
(
mod
);
timer
ts
{};
using
seconds
=
std
::
chrono
::
duration
<
double
>
;
trace
(
"Module: "
,
mod
->
name
(),
", Pass: "
,
p
.
name
());
const
double
t1
=
ts
.
record
<
seconds
>
();
assert
(
mod
->
validate
()
==
mod
->
end
());
assert
(
mod
->
validate
()
==
mod
->
end
());
p
.
apply
(
*
this
);
if
(
enabled
(
MIGRAPHX_TIME_PASSES
{}))
{
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
auto
ms
=
time
<
milliseconds
>
([
&
]
{
p
.
apply
(
*
this
);
});
std
::
cout
<<
p
.
name
()
<<
": "
<<
ms
<<
"ms
\n
"
;
}
else
{
p
.
apply
(
*
this
);
}
trace
(
*
mod
);
trace
(
*
mod
);
validate_pass
(
*
mod
,
p
,
*
t
);
validate_pass
(
*
mod
,
p
,
*
t
);
const
double
t2
=
ts
.
record
<
seconds
>
();
trace
(
"Pass: "
,
p
.
name
(),
" completed in (s): "
,
(
t2
-
t1
));
}
}
};
};
...
...
src/program.cpp
View file @
9bff4331
...
@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
...
@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert
(
not
this
->
is_compiled
());
assert
(
not
this
->
is_compiled
());
this
->
impl
->
target_name
=
t
.
name
();
this
->
impl
->
target_name
=
t
.
name
();
this
->
impl
->
ctx
=
t
.
get_context
();
this
->
impl
->
ctx
=
t
.
get_context
();
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
(
*
this
);
options
.
trace
(
*
this
);
options
.
trace
();
options
.
trace
();
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
ctx
,
options
);
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
ctx
,
options
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
auto
mods
=
this
->
get_modules
();
auto
mods
=
this
->
get_modules
();
// Validate and finalize
// Validate and finalize
for
(
const
auto
&
mod
:
reverse
(
mods
))
for
(
const
auto
&
mod
:
reverse
(
mods
))
{
{
...
@@ -336,7 +334,8 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -336,7 +334,8 @@ std::vector<argument> generic_eval(const module* mod,
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
{
{
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
"} for parameter: "
+
param_name
);
"} for parameter: "
+
param_name
+
" should be: "
+
to_string
(
ins
->
get_shape
()));
}
}
return
param
;
return
param
;
}));
}));
...
@@ -380,7 +379,7 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -380,7 +379,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}));
}
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
if
(
not
ins
->
get_shape
().
dynamic
())
if
(
not
ins
->
get_shape
().
any_of_
dynamic
())
{
{
assert
(
results
.
at
(
ins
).
get_shape
()
==
ins
->
get_shape
());
assert
(
results
.
at
(
ins
).
get_shape
()
==
ins
->
get_shape
());
}
}
...
@@ -854,6 +853,25 @@ void program::print_graph(std::ostream& os, bool brief) const
...
@@ -854,6 +853,25 @@ void program::print_graph(std::ostream& os, bool brief) const
mm
->
print_graph
(
os
,
brief
);
mm
->
print_graph
(
os
,
brief
);
}
}
/// Write Python source to `os` that recreates this program: a
/// `migraphx.program()` followed by one variable per module, each filled in
/// by the module's own print_py.
void program::print_py(std::ostream& os) const
{
    auto vec_modules = this->get_modules();
    // Instruction-to-variable names, threaded from one module to the next.
    std::unordered_map<instruction_ref, std::string> names;
    os << "p = migraphx.program()\n";
    for(const auto& mod : vec_modules)
    {
        const std::string mod_name = mod->name();
        const std::string var_name = "m" + mod_name;
        os << var_name;
        os << " = ";
        // The main module already exists on the program; others are created.
        if(mod_name == "main")
        {
            os << "p.get_main_module()";
        }
        else
        {
            os << "p.create_module(\"";
            os << mod_name;
            os << "\");";
        }
        os << std::endl;
        names = mod->print_py(os, var_name, names);
        os << std::endl;
    }
}
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
{
{
auto
vec_modules
=
this
->
get_modules
();
auto
vec_modules
=
this
->
get_modules
();
...
...
src/propagate_constant.cpp
View file @
9bff4331
...
@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
...
@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
return
false
;
return
false
;
}
}
bool
is_const
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
bool
is_const
_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
void
propagate_constant
::
apply
(
module
&
m
)
const
{
{
...
@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
...
@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal
// Find instructions that can be evaluated to a literal
for
(
auto
i
:
iterator_for
(
m
))
for
(
auto
i
:
iterator_for
(
m
))
{
{
if
(
is_const
(
i
)
and
i
!=
last
)
const
bool
is_const
=
is_const_ins
(
i
);
if
(
is_const
and
i
!=
last
)
continue
;
continue
;
std
::
copy_if
(
if
(
i
==
last
and
is_const
)
i
->
inputs
().
begin
(),
{
i
->
inputs
().
end
(),
const_instrs
.
insert
(
i
);
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
}
[
&
](
const
instruction_ref
ins
)
{
return
is_const
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
else
{
std
::
copy_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
[
&
](
const
instruction_ref
ins
)
{
return
is_const_ins
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
}
}
}
// Compute literals in parallel
// Compute literals in parallel
...
...
src/py/migraphx_py.cpp
View file @
9bff4331
...
@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"is_compiled"
,
&
migraphx
::
program
::
is_compiled
)
.
def
(
"is_compiled"
,
&
migraphx
::
program
::
is_compiled
)
.
def
(
.
def
(
"compile"
,
"compile"
,
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
)
{
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
,
bool
exhaustive_tune
)
{
migraphx
::
compile_options
options
;
migraphx
::
compile_options
options
;
options
.
offload_copy
=
offload_copy
;
options
.
offload_copy
=
offload_copy
;
options
.
fast_math
=
fast_math
;
options
.
fast_math
=
fast_math
;
options
.
exhaustive_tune
=
exhaustive_tune
;
p
.
compile
(
t
,
options
);
p
.
compile
(
t
,
options
);
},
},
py
::
arg
(
"t"
),
py
::
arg
(
"t"
),
py
::
arg
(
"offload_copy"
)
=
true
,
py
::
arg
(
"offload_copy"
)
=
true
,
py
::
arg
(
"fast_math"
)
=
true
)
py
::
arg
(
"fast_math"
)
=
true
,
py
::
arg
(
"exhaustive_tune"
)
=
false
)
.
def
(
"get_main_module"
,
[](
const
migraphx
::
program
&
p
)
{
return
p
.
get_main_module
();
})
.
def
(
"get_main_module"
,
[](
const
migraphx
::
program
&
p
)
{
return
p
.
get_main_module
();
})
.
def
(
.
def
(
"create_module"
,
"create_module"
,
...
...
src/register_op.cpp
View file @
9bff4331
...
@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
...
@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
static
std
::
unordered_map
<
std
::
string
,
operation
>
m
;
// NOLINT
static
std
::
unordered_map
<
std
::
string
,
operation
>
m
;
// NOLINT
return
m
;
return
m
;
}
}
/// Touch the operator map so that its function-local static is constructed;
/// the returned reference is intentionally discarded.
void register_op_init() { static_cast<void>(op_map()); }
void
register_op
(
const
operation
&
op
)
{
op_map
()[
op
.
name
()]
=
op
;
}
void
register_op
(
const
operation
&
op
)
{
op_map
()[
op
.
name
()]
=
op
;
}
/// Remove the operator registered under `op_name` from the operator map.
/// Asserts (in builds with assertions enabled) that such an operator exists.
void unregister_op(const std::string& op_name)
{
    auto& registry = op_map();
    assert(registry.count(op_name));
    registry.erase(op_name);
}
operation
load_op
(
const
std
::
string
&
name
)
operation
load_op
(
const
std
::
string
&
name
)
{
{
return
at
(
op_map
(),
name
,
"Operator not found: "
+
name
);
return
at
(
op_map
(),
name
,
"Operator not found: "
+
name
);
...
...
Prev
1
2
3
4
5
6
7
8
9
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment