Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
30c49503
Commit
30c49503
authored
Mar 23, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
870a396b
09aaa63e
Changes
202
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
690 additions
and
668 deletions
+690
-668
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+3
-0
src/memory_coloring.cpp
src/memory_coloring.cpp
+405
-0
src/module.cpp
src/module.cpp
+3
-1
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+17
-10
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+2
-1
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+13
-7
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+28
-24
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+20
-2
src/onnx/parse_loop.cpp
src/onnx/parse_loop.cpp
+1
-1
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+7
-2
src/onnx/parse_trilu.cpp
src/onnx/parse_trilu.cpp
+90
-0
src/onnx/parse_where.cpp
src/onnx/parse_where.cpp
+34
-18
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+0
-376
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-193
src/optimize_module.cpp
src/optimize_module.cpp
+15
-6
src/pass_manager.cpp
src/pass_manager.cpp
+11
-10
src/program.cpp
src/program.cpp
+4
-5
src/propagate_constant.cpp
src/propagate_constant.cpp
+16
-7
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+11
-5
src/register_op.cpp
src/register_op.cpp
+10
-0
No files found.
src/include/migraphx/shape.hpp
View file @
30c49503
...
...
@@ -243,6 +243,9 @@ struct shape
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
/// Return true if this shape or any of the sub_shapes are dynamic
bool
any_of_dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
...
src/memory_coloring.cpp
0 → 100644
View file @
30c49503
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DEBUG_MEMORY_COLORING
);
using
instruction_set
=
std
::
unordered_set
<
instruction_ref
>
;
using
instruction_set_map
=
std
::
unordered_map
<
instruction_ref
,
instruction_set
>
;
// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template
<
class
F
>
void
liveness
(
const
module
&
m
,
F
f
)
{
auto
implicit_deps
=
m
.
calc_implicit_deps
();
instruction_set
live_set
;
auto
rp
=
reverse
(
m
);
for
(
auto
rins
:
iterator_for
(
rp
))
// NOLINT
{
// The base iterator is one ahead, so we need to use the previous iterator
auto
ins
=
std
::
prev
(
rins
.
base
());
// Add live variables
auto
add_live_variables
=
[
&
](
const
auto
&
inputs
)
{
for
(
auto
input
:
inputs
)
{
auto
i
=
instruction
::
get_output_alias
(
input
);
// Skip if variable comes from parent
if
(
not
m
.
has_instruction
(
i
))
continue
;
live_set
.
insert
(
i
);
}
};
add_live_variables
(
ins
->
inputs
());
add_live_variables
(
implicit_deps
[
ins
]);
// Remove last usage
auto
it
=
live_set
.
find
(
ins
);
if
(
it
!=
live_set
.
end
())
{
live_set
.
erase
(
it
);
f
(
ins
,
live_set
);
}
}
}
// This will build the conflict table or interference graph. This is
// essentially a map from one instruction to a set of instruction that are
// used together. Each instruction will be the allocation instruction.
instruction_set_map
build_conflict_table
(
const
module
&
m
,
std
::
string
allocation_op
)
{
instruction_set_map
conflict_table
;
liveness
(
m
,
[
&
](
auto
ins
,
auto
live_set
)
{
// Skip variables that aren't allocations
if
(
ins
->
name
()
!=
allocation_op
)
return
;
// Skip zero allocations
if
(
ins
->
get_shape
().
bytes
()
==
0
)
return
;
conflict_table
[
ins
];
for
(
auto
i
:
live_set
)
{
if
(
i
==
ins
)
continue
;
// Skip variables that aren't allocations
if
(
i
->
name
()
!=
allocation_op
)
continue
;
// Skip zero allocations
if
(
i
->
get_shape
().
bytes
()
==
0
)
continue
;
conflict_table
[
i
].
insert
(
ins
);
conflict_table
[
ins
].
insert
(
i
);
}
});
assert
(
std
::
all_of
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
[](
auto
&&
pp
)
{
return
pp
.
second
.
count
(
pp
.
first
)
==
0
;
}));
return
conflict_table
;
}
// Check if intervals overlap
bool
is_overlap
(
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
x
,
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
y
)
{
return
std
::
max
(
x
.
first
,
y
.
first
)
<
std
::
min
(
x
.
second
,
y
.
second
);
}
struct
allocation_segment
{
using
segment
=
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
;
std
::
unordered_map
<
instruction_ref
,
segment
>
ins2segment
;
const
segment
*
add_segment
(
instruction_ref
ins
,
segment
s
)
{
return
&
(
ins2segment
[
ins
]
=
s
);
}
const
segment
*
get_segment
(
instruction_ref
ins
)
const
{
auto
it
=
ins2segment
.
find
(
ins
);
if
(
it
==
ins2segment
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
// Remove segment for an instruction
void
remove
(
instruction_ref
ins
)
{
auto
it
=
ins2segment
.
find
(
ins
);
if
(
it
!=
ins2segment
.
end
())
{
ins2segment
.
erase
(
it
);
}
}
std
::
size_t
max
()
{
std
::
size_t
n
=
0
;
for
(
auto
&&
pp
:
ins2segment
)
{
auto
seg
=
pp
.
second
;
n
=
std
::
max
(
n
,
seg
.
second
);
}
return
n
;
}
template
<
class
Iterator
>
static
bool
overlaps
(
Iterator
first
,
Iterator
last
,
const
segment
&
s
)
{
return
std
::
any_of
(
first
,
last
,
[
&
](
auto
&&
t
)
{
return
is_overlap
(
s
,
t
);
});
}
static
bool
overlaps
(
const
std
::
set
<
segment
>&
segments
,
const
segment
&
s
)
{
return
overlaps
(
segments
.
begin
(),
segments
.
end
(),
s
);
}
static
auto
find_gap
(
const
std
::
set
<
segment
>&
segments
,
std
::
size_t
n
)
{
std
::
size_t
max_end
=
0
;
return
std
::
adjacent_find
(
segments
.
begin
(),
segments
.
end
(),
[
&
](
segment
x
,
segment
y
)
{
if
(
x
.
second
<
max_end
)
return
false
;
max_end
=
x
.
second
;
if
(
is_overlap
(
x
,
y
))
return
false
;
assert
(
y
.
first
>=
x
.
second
);
auto
k
=
y
.
first
-
x
.
second
;
return
(
k
>=
n
);
});
}
static
std
::
size_t
max_type_size
(
const
shape
&
s
)
{
return
std
::
accumulate
(
s
.
sub_shapes
().
begin
(),
s
.
sub_shapes
().
end
(),
s
.
type_size
(),
[](
auto
size
,
const
auto
&
sub
)
{
return
std
::
max
(
size
,
max_type_size
(
sub
));
});
}
static
std
::
size_t
compute_alignment
(
instruction_ref
ins
)
{
auto
alignment
=
max_type_size
(
ins
->
get_shape
());
// A rough estimate for the total number of elements
auto
n
=
ins
->
get_shape
().
bytes
()
/
alignment
;
// Check for vectorized alignment
if
(
n
>
4
)
{
auto
d
=
n
%
4
;
if
(
d
==
0
)
alignment
*=
4
;
if
(
d
==
2
)
alignment
*=
2
;
}
return
alignment
;
}
static
segment
next_segment
(
std
::
set
<
segment
>&
segments
,
instruction_ref
ins
,
std
::
size_t
alignment
)
{
assert
(
ins
->
get_shape
().
bytes
()
>
0
);
// Compute alignment
auto
n
=
1
+
(
ins
->
get_shape
().
bytes
()
-
1
)
/
alignment
;
assert
(
n
>
0
);
auto
start
=
0
;
// Insert at end if it cant fit at the begining
if
(
segments
.
empty
()
or
segments
.
begin
()
->
first
<=
n
)
{
auto
it
=
find_gap
(
segments
,
n
);
if
(
it
==
segments
.
end
())
it
=
std
::
max_element
(
segments
.
begin
(),
segments
.
end
(),
[
&
](
segment
x
,
segment
y
)
{
return
x
.
second
<
y
.
second
;
});
if
(
it
!=
segments
.
end
())
start
=
it
->
second
;
}
auto
s
=
segment
{
start
,
start
+
n
};
assert
(
not
overlaps
(
segments
,
s
));
segments
.
insert
(
s
);
return
s
;
}
static
std
::
unordered_map
<
instruction_ref
,
int
>
create_allocation_index
(
const
module
&
m
,
const
instruction_set_map
&
conflict_table
)
{
std
::
unordered_map
<
instruction_ref
,
int
>
result
;
int
i
=
0
;
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
not
contains
(
conflict_table
,
ins
))
continue
;
result
[
ins
]
=
i
++
;
}
return
result
;
}
// Build the allocation_color class from the conflict_table
static
allocation_segment
build
(
const
module
&
m
,
const
instruction_set_map
&
conflict_table
,
std
::
size_t
alignment
)
{
allocation_segment
as
{};
std
::
vector
<
instruction_ref
>
conflict_queue
;
// Add all allocations to the conflict_queue
std
::
transform
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
std
::
back_inserter
(
conflict_queue
),
[](
auto
&&
pp
)
{
return
pp
.
first
;
});
auto
alloc_index
=
create_allocation_index
(
m
,
conflict_table
);
// Sort the conflict queue so we process the allocation with the most
// number of adjacent allocations first
std
::
sort
(
conflict_queue
.
begin
(),
conflict_queue
.
end
(),
by
(
std
::
greater
<>
{},
[
&
](
auto
x
)
{
return
std
::
make_tuple
(
conflict_table
.
at
(
x
).
size
(),
x
->
get_shape
().
bytes
(),
alloc_index
.
at
(
x
));
}));
// Process the conflict_queue, we refer to the current allocation as
// the parent and the adjacent allocations as children
for
(
auto
parent
:
conflict_queue
)
{
// Sort children by size
std
::
vector
<
instruction_ref
>
children
(
conflict_table
.
at
(
parent
).
begin
(),
conflict_table
.
at
(
parent
).
end
());
std
::
sort
(
children
.
begin
(),
children
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
x
)
{
return
std
::
make_tuple
(
x
->
get_shape
().
bytes
(),
alloc_index
.
at
(
x
));
}));
assert
(
not
contains
(
children
,
parent
));
// This set is to track the segments already processed
std
::
set
<
segment
>
segments
;
// Add all segments for the children to the segments already processed
transform_if
(
children
.
begin
(),
children
.
end
(),
std
::
inserter
(
segments
,
segments
.
begin
()),
[
&
](
auto
child
)
{
return
as
.
get_segment
(
child
);
},
[
&
](
auto
child
)
{
return
*
as
.
get_segment
(
child
);
});
assert
(
as
.
get_segment
(
parent
)
==
nullptr
);
as
.
add_segment
(
parent
,
next_segment
(
segments
,
parent
,
alignment
));
}
// Reduce the number of segments
for
(
std
::
size_t
n
=
0
;
n
<
3
;
n
++
)
{
for
(
auto
parent
:
conflict_queue
)
{
auto
children
=
conflict_table
.
at
(
parent
);
// This set is to track the segments already processed
std
::
set
<
segment
>
segments
;
// Add all segments for the children to the segments already processed
transform_if
(
children
.
begin
(),
children
.
end
(),
std
::
inserter
(
segments
,
segments
.
begin
()),
[
&
](
auto
child
)
{
return
as
.
get_segment
(
child
);
},
[
&
](
auto
child
)
{
return
*
as
.
get_segment
(
child
);
});
// Get the segment for the parent
const
auto
*
parent_segment
=
as
.
get_segment
(
parent
);
assert
(
parent_segment
!=
nullptr
);
auto
s
=
next_segment
(
segments
,
parent
,
alignment
);
if
(
s
!=
*
parent_segment
and
s
.
second
<=
as
.
max
())
{
as
.
add_segment
(
parent
,
s
);
}
}
}
return
as
;
}
};
static
std
::
size_t
find_max_alignment
(
const
module
&
m
,
const
std
::
string
&
allocation_op
)
{
std
::
size_t
alignment
=
1
;
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
alignment
=
std
::
max
(
allocation_segment
::
compute_alignment
(
ins
),
alignment
);
}
return
alignment
;
}
void
memory_coloring
::
apply
(
module
&
m
)
const
{
const
std
::
size_t
alignment
=
find_max_alignment
(
m
,
allocation_op
);
auto
conflict_table
=
build_conflict_table
(
m
,
allocation_op
);
auto
as
=
allocation_segment
::
build
(
m
,
conflict_table
,
alignment
);
// All allocations should have a segment
assert
(
std
::
all_of
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
[
&
](
auto
&&
pp
)
{
return
as
.
get_segment
(
pp
.
first
);
}));
// Adjacent allocations should not have overlapping segments
assert
(
std
::
none_of
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
[
&
](
auto
&&
pp
)
{
auto
*
x
=
as
.
get_segment
(
pp
.
first
);
return
std
::
any_of
(
pp
.
second
.
begin
(),
pp
.
second
.
end
(),
[
&
](
auto
ins
)
{
auto
*
y
=
as
.
get_segment
(
ins
);
assert
(
x
and
y
);
return
is_overlap
(
*
x
,
*
y
);
});
}));
// Print out segments
if
(
enabled
(
MIGRAPHX_DEBUG_MEMORY_COLORING
{}))
{
for
(
auto
&&
pp
:
conflict_table
)
{
std
::
cout
<<
"------- conflict -------"
<<
std
::
endl
;
auto
s1
=
as
.
ins2segment
.
at
(
pp
.
first
);
std
::
cout
<<
s1
.
first
<<
", "
<<
s1
.
second
<<
": "
;
m
.
debug_print
(
pp
.
first
);
for
(
auto
ins
:
pp
.
second
)
{
auto
s2
=
as
.
ins2segment
.
at
(
ins
);
std
::
cout
<<
s2
.
first
<<
", "
<<
s2
.
second
<<
": "
;
m
.
debug_print
(
ins
);
}
}
}
// Total memory
std
::
size_t
n
=
as
.
max
()
*
alignment
;
// Replace allocations
auto
mem
=
m
.
add_parameter
(
"scratch"
,
shape
{
shape
::
int8_type
,
{
n
}});
for
(
auto
&&
[
ins
,
seg
]
:
as
.
ins2segment
)
{
assert
(
ins
->
name
()
==
allocation_op
);
auto
s
=
ins
->
get_shape
();
std
::
size_t
offset
=
seg
.
first
*
alignment
;
assert
(
offset
<
n
);
m
.
replace_instruction
(
ins
,
op
::
load
{
s
,
offset
},
mem
);
}
// Replace zero allocation
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
m
.
replace_instruction
(
ins
,
op
::
load
{
ins
->
get_shape
(),
0
},
mem
);
}
// Remove scratch parameter if its not used
if
(
mem
->
outputs
().
empty
())
{
m
.
remove_instruction
(
mem
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/module.cpp
View file @
30c49503
...
...
@@ -166,6 +166,7 @@ void module::assign(const module& m)
auto
s
=
ins
->
get_shape
();
copy_ins
=
impl
->
insert
(
impl
->
instructions
.
end
(),
{
builtin
::
param
{
name
,
order
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
}
else
if
(
ins
->
name
()
==
"@outline"
)
{
...
...
@@ -822,7 +823,8 @@ static void print_make_op(std::ostream& os, const operation& op)
static
void
print_py_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
os
<<
"migraphx.shape("
<<
s
.
type_string
()
<<
", lens="
<<
to_json_string
(
s
.
lens
());
os
<<
"migraphx.shape(type="
<<
to_json_string
(
s
.
type_string
())
<<
", lens="
<<
to_json_string
(
s
.
lens
());
if
(
not
s
.
standard
())
os
<<
", strides="
<<
to_json_string
(
s
.
strides
());
os
<<
")"
;
...
...
src/normalize_attributes.cpp
View file @
30c49503
...
...
@@ -30,13 +30,16 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// different attributes
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) clip_min(default)/not_clip_min
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
/**
* Parameters:
* vec: the vector attribute to normalize
* axes: the operator's axes attribute if it exists, empty otherwise
* val: the normalize_axes key and options. Ex: normalize["axes"] =
* value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
* normalize_attributes(op&, lens)
*
* See normalize_attribute.hpp for explaining the options.
*/
auto
tune_attribute
(
const
std
::
vector
<
int64_t
>&
vec
,
const
std
::
vector
<
int64_t
>&
axes
,
const
value
&
val
,
...
...
@@ -151,6 +154,11 @@ auto tune_pad_attribute(const value& val)
return
result
;
}
/**
* Assumptions:
* Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input.
*/
bool
normalize_attributes
(
operation
&
op
,
const
std
::
vector
<
std
::
size_t
>&
lens
)
{
bool
tuned
=
false
;
...
...
@@ -158,9 +166,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto
val
=
op
.
to_value
();
if
(
attrs
.
contains
(
"normalize_padding"
))
{
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
// for now, assume the dimensions to pad start at dim 2
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
auto
padding_start
=
2
;
if
(
padding_size
==
2
*
(
lens
.
size
()
-
padding_start
))
...
...
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
30c49503
...
...
@@ -113,7 +113,8 @@ struct onnx_parser
void
parse_from
(
std
::
istream
&
is
,
std
::
string
name
=
""
);
void
parse_from
(
const
void
*
data
,
std
::
size_t
size
);
void
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
);
std
::
vector
<
instruction_ref
>
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
=
false
);
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
...
...
src/onnx/onnx_parser.cpp
View file @
30c49503
...
...
@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
mm
,
model
.
graph
());
(
void
)
this
->
parse_graph
(
mm
,
model
.
graph
());
}
}
else
...
...
@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
mm
,
model
.
graph
());
(
void
)
this
->
parse_graph
(
mm
,
model
.
graph
());
}
}
else
...
...
@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return
version
;
}
void
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
std
::
vector
<
instruction_ref
>
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
)
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
for
(
auto
&&
f
:
graph
.
initializer
())
...
...
@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std
::
back_inserter
(
output_ins
),
[
&
](
const
auto
&
name
)
{
return
instructions
[
name
];
});
// add the return instuction
mod
->
add_return
(
output_ins
);
if
(
not
inlining
)
{
// add the return instuction
mod
->
add_return
(
output_ins
);
// Remove instructions added in module (this is turned off for subgraph inlining)
erase_if
(
instructions
,
[
&
](
auto
&&
p
)
{
return
mod
->
has_instruction
(
p
.
second
);
});
}
// remove instructions added in this mod
erase_if
(
instructions
,
[
&
](
auto
&&
p
)
{
return
mod
->
has_instruction
(
p
.
second
);
});
return
output_ins
;
}
literal
onnx_parser
::
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
...
...
src/onnx/parse_gemm.cpp
View file @
30c49503
...
...
@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
auto
dot_ins
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
if
(
args
.
size
()
==
3
)
{
// TODO: support dynamic C input
if
(
std
::
any_of
(
args
.
cbegin
(),
args
.
cend
(),
[](
auto
in_arg
)
{
return
in_arg
->
get_shape
().
dynamic
();
}))
if
(
not
float_equal
(
beta
,
0.0
f
))
{
MIGRAPHX_THROW
(
"PARSE_GEMM: C input not handled for dynamic input shapes"
);
}
if
(
not
float_equal
(
beta
,
0.0
f
)
and
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_arg
=
args
[
2
];
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
auto
c_arg
=
args
[
2
];
if
(
dot_ins
->
get_shape
().
dynamic
())
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
),
args
[
2
],
dot_ins
);
}
else
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_c
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
beta_c
->
get_shape
().
type
()
!=
dot_type
)
if
(
not
float_equal
(
beta
,
1.0
f
))
{
beta_c
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
beta_c
);
auto
beta_literal
=
info
.
add_literal
(
beta
);
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
c_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_c
);
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
return
ret
;
return
dot_ins
;
}
};
...
...
src/onnx/parse_if.cpp
View file @
30c49503
...
...
@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!"
);
}
// Fold instruction if condition is constant thus can be evaled
// prior to inference
if
(
args
.
front
()
->
can_eval
())
{
auto
cond_arg
=
args
.
front
()
->
eval
();
auto
*
mod
=
info
.
mod
;
// then branch
if
(
cond_arg
.
at
<
bool
>
())
{
return
parser
.
parse_graph
(
mod
,
then_graph
,
true
);
}
// else branch
else
{
return
parser
.
parse_graph
(
mod
,
else_graph
,
true
);
}
}
std
::
string
then_name
=
info
.
name
+
"_if"
;
module_ref
then_mdl
=
parser
.
prog
.
create_module
(
then_name
);
...
...
@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref
else_mdl
=
parser
.
prog
.
create_module
(
else_name
);
// parse the then sub_graph
parser
.
parse_graph
(
then_mdl
,
then_graph
);
(
void
)
parser
.
parse_graph
(
then_mdl
,
then_graph
);
// parse_the else sub_graph
parser
.
parse_graph
(
else_mdl
,
else_graph
);
(
void
)
parser
.
parse_graph
(
else_mdl
,
else_graph
);
auto
then_out_shapes
=
then_mdl
->
get_output_shapes
();
auto
else_out_shapes
=
else_mdl
->
get_output_shapes
();
...
...
src/onnx/parse_loop.cpp
View file @
30c49503
...
...
@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref
sub_mod
=
parser
.
prog
.
create_module
(
mod_name
);
// parse the sub_graph
parser
.
parse_graph
(
sub_mod
,
sub_graph
);
(
void
)
parser
.
parse_graph
(
sub_mod
,
sub_graph
);
auto
ret
=
info
.
add_instruction
(
make_op
(
"loop"
,
{{
"max_iterations"
,
max_iterations
}}),
args
,
{
sub_mod
});
...
...
src/onnx/parse_slice.cpp
View file @
30c49503
...
...
@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std
::
vector
<
int64_t
>
steps
;
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice
// to decide whether MIGRAPHX can handle this slice
.
if
(
args
.
size
()
==
5
)
{
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
...
...
@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
}
// If axes arg is not given, the default is all of them.
if
(
op
.
axes
.
empty
())
{
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
op
.
axes
=
axes
;
}
...
...
@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
steps
.
size
()))
{
if
(
steps
[
i
]
>=
0
)
...
...
@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
if
(
not
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
...
...
test/context_test
.cpp
→
src/onnx/parse_trilu
.cpp
View file @
30c49503
...
...
@@ -21,20 +21,70 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/
serialize
.hpp>
#include <migraphx/
c
on
text
.hpp>
#include <migraphx/r
ef/context
.hpp>
#include <migraphx/
fun
ction
al
.hpp>
#include <
test
.hpp>
#include <migraphx/
onnx/op_parser
.hpp>
#include <migraphx/on
nx/checks
.hpp>
#include <migraphx/r
anges
.hpp>
#include <migraphx/
instru
ction.hpp>
#include <
migraphx/make_op
.hpp>
TEST_CASE
(
context
)
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_trilu
:
op_parser
<
parse_trilu
>
{
migraphx
::
context
ctx
=
migraphx
::
ref
::
context
{};
migraphx
::
value
v
=
ctx
.
to_value
();
EXPECT
(
v
.
empty
());
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Trilu"
}};
}
instruction_ref
parse
(
const
op_desc
&
,
const
onnx_parser
&
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
input_shape
=
args
[
0
]
->
get_shape
();
assert
(
input_shape
.
ndim
()
>=
2
);
auto
input_lens
=
input_shape
.
lens
();
size_t
num_rows
=
*
(
input_lens
.
rbegin
()
+
1
);
size_t
num_cols
=
input_lens
.
back
();
int
k
=
0
;
bool
upper
=
true
;
if
(
args
.
size
()
>
1
)
{
auto
arg_k
=
args
[
1
]
->
eval
();
check_arg_empty
(
arg_k
,
"PARSE_TRILU: dynamic k not supported"
);
k
=
arg_k
.
at
<
int
>
();
}
if
(
k
<
0
)
MIGRAPHX_THROW
(
"PARSE_TRILU: negative k values not supported"
);
if
(
contains
(
info
.
attributes
,
"upper"
))
{
upper
=
static_cast
<
bool
>
(
info
.
attributes
.
at
(
"upper"
).
i
());
}
shape
::
type_t
output_type
=
args
[
0
]
->
get_shape
().
type
();
// when creating the mask, if upper == 1,
// the inner triangle will have values set to 0
std
::
vector
<
bool
>
mask_mat
(
num_rows
*
num_cols
,
upper
);
for
(
size_t
i
=
0
;
i
<
num_rows
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
std
::
min
(
k
,
static_cast
<
int
>
(
num_cols
));
j
++
)
{
mask_mat
[
i
*
num_cols
+
j
]
=
not
upper
;
}
k
++
;
}
auto
mask
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
output_type
,
{
num_rows
,
num_cols
}},
mask_mat
});
migraphx
::
context
cpu_ctx
=
migraphx
::
ref
::
context
{}
;
cpu_ctx
.
from_value
(
v
);
}
return
info
.
add_broadcastable_binary_op
(
"mul"
,
mask
,
args
[
0
])
;
}
}
;
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_where.cpp
View file @
30c49503
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
// TODO: broadcasting for dynamic shapes is only implemented
// for binary ops at time of writing, not ternary ops.
// When it becomes available, add multibroadcasting steps in the dynamic shape case.
// For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
}))
{
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
}
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
else
if
(
std
::
none_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
})
)
{
args
[
1
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
}
// If shapes are static and any are broadcasted, insert multibroadcast ops
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
2
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
}
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
1
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
2
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
}
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
}
else
MIGRAPHX_THROW
(
"PARSE_WHERE: doesn't support mixed static and dynamic shape inputs"
);
}
};
...
...
src/opt/memory_coloring_impl.cpp
deleted
100644 → 0
View file @
870a396b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring_impl
::
run
()
{
// calc implicit depdendencies
mod_implicit_deps
=
p_mod
->
calc_implicit_deps
();
MIGRAPHX_DEBUG
(
dump
(
"---Before memory coloring---"
));
MIGRAPHX_DEBUG
(
dump_module
());
build
();
if
(
num_of_lives
!=
0
)
{
MIGRAPHX_DEBUG
(
dump_intervals
());
// Coloring
while
(
not
alloc_queue
.
empty
())
{
interval_ptr
interval
=
alloc_queue
.
top
();
allocate
(
interval
);
alloc_queue
.
pop
();
}
// rewrite happens after all modules are processed
rewrite
();
if
(
enable_verify
)
verify
();
}
}
bool
memory_coloring_impl
::
allocate
(
interval_ptr
interval
)
{
shape
s
=
interval
->
result
;
std
::
size_t
size
=
s
.
bytes
();
if
(
size
==
0
)
return
false
;
std
::
size_t
element_size
=
(
s
.
elements
()
==
0
?
4
:
(
size
/
s
.
elements
()));
live_range
&
segment
=
interval
->
segment
;
int
vn
=
segment
.
vn
;
std
::
priority_queue
<
live_range
*
,
std
::
vector
<
live_range
*>
,
ordering
>
conflict_queue
;
std
::
unordered_map
<
long
long
,
live_range
*>
offset2_live
;
offset2_live
.
clear
();
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
long
long
offset
=
range
->
offset
;
if
(
offset
!=
invalid_offset
)
{
conflict_queue
.
push
(
range
);
if
(
offset2_live
.
find
(
offset
)
==
offset2_live
.
end
())
{
offset2_live
[
offset
]
=
range
;
}
else
{
live_range
*
prev
=
offset2_live
[
offset
];
assert
(
prev
->
offset
==
offset
);
if
(
prev
->
size
<
range
->
size
)
offset2_live
[
offset
]
=
range
;
}
}
}
}
std
::
size_t
offset
=
0
;
while
(
not
conflict_queue
.
empty
())
{
live_range
*
range
=
conflict_queue
.
top
();
std
::
size_t
iter_offset
=
range
->
offset
;
if
(
offset
>
iter_offset
)
{
offset
=
std
::
max
(
offset
,
iter_offset
+
range
->
size
);
}
else
if
(
offset2_live
[
iter_offset
]
==
range
)
{
if
((
iter_offset
>
offset
)
&&
(
iter_offset
-
offset
)
>=
size
)
{
break
;
}
offset
=
iter_offset
+
range
->
size
;
}
// alignment
if
((
offset
%
element_size
)
!=
0
)
offset
+=
(
element_size
-
(
offset
%
element_size
));
conflict_queue
.
pop
();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset
=
(
offset
+
3
)
/
4
*
4
;
segment
.
offset
=
offset
;
MIGRAPHX_DEBUG
(
segment
.
dump
());
required_bytes
=
std
::
max
(
required_bytes
,
offset
+
segment
.
size
);
return
true
;
}
void
memory_coloring_impl
::
build
()
{
std
::
size_t
num_of_instrs
=
p_mod
->
size
();
if
(
num_of_instrs
==
0
)
return
;
auto
cur_points
=
num_of_instrs
*
2
;
instruction_ref
iter
=
p_mod
->
end
();
instruction_ref
begin
=
p_mod
->
begin
();
std
::
vector
<
instruction_ref
>
dead_instrs
;
std
::
set
<
int
>
live_set
;
// Build live intervals.
live_intervals
.
resize
(
num_of_instrs
);
do
{
iter
=
std
::
prev
(
iter
);
const
instruction
*
p_iter
=
&
(
*
iter
);
interval_ptr
def_interval
=
nullptr
;
bool
is_dead
=
false
;
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
def_interval
=
instr2_live
[
p_iter
];
bool
is_lit
=
is_literal
(
iter
);
if
(
is_allocate
(
iter
)
or
is_lit
)
{
live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
get_shape
();
def_interval
->
is_literal
=
is_lit
;
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
if
(
not
is_lit
or
unify_literals
)
alloc_queue
.
push
(
def_interval
);
live_set
.
erase
(
range
.
vn
);
}
}
else
if
(
not
is_param
(
iter
)
&&
not
is_outline
(
iter
)
&&
not
is_check_context
(
iter
))
{
is_dead
=
true
;
}
auto
inputs
=
iter
->
inputs
();
if
(
contains
(
mod_implicit_deps
,
iter
))
{
const
auto
&
impl_deps
=
mod_implicit_deps
.
at
(
iter
);
inputs
.
insert
(
inputs
.
end
(),
impl_deps
.
begin
(),
impl_deps
.
end
());
}
for
(
auto
&&
arg
:
inputs
)
{
if
(
not
p_mod
->
has_instruction
(
arg
))
continue
;
if
(
is_param
(
arg
)
or
is_outline
(
arg
))
{
if
(
is_output_param
(
arg
))
is_dead
=
false
;
if
(
def_interval
!=
nullptr
)
{
def_interval
->
is_live_on_entry
=
true
;
}
continue
;
}
const
instruction
*
p_arg
=
&
(
*
instruction
::
get_output_alias
(
arg
));
if
(
instr2_live
.
find
(
p_arg
)
==
instr2_live
.
end
())
{
// First time see a use, create a live interval.
int
id
=
num_of_lives
++
;
interval_ptr
interval
=
&
(
live_intervals
[
id
]);
interval
->
id
=
id
;
interval
->
segment
.
end
=
cur_points
;
interval
->
segment
.
vn
=
++
max_value_number
;
interval
->
add_use
(
cur_points
);
instr2_live
[
p_arg
]
=
interval
;
add_conflicts
(
live_set
,
max_value_number
);
live_set
.
insert
(
max_value_number
);
live_ranges
[
max_value_number
]
=
&
(
interval
->
segment
);
earliest_end_point
=
cur_points
;
if
(
latest_end_point
==
-
1
)
latest_end_point
=
cur_points
;
}
else
{
interval_ptr
interval
=
instr2_live
[
p_arg
];
interval
->
add_use
(
cur_points
);
assert
(
live_set
.
find
(
interval
->
id
)
!=
live_set
.
end
());
}
}
if
(
is_dead
)
dead_instrs
.
push_back
(
iter
);
cur_points
-=
2
;
}
while
(
iter
!=
begin
);
}
void
memory_coloring_impl
::
rewrite
()
{
std
::
vector
<
std
::
size_t
>
dims
;
dims
.
push_back
((
required_bytes
+
sizeof
(
float
)
-
1
)
/
sizeof
(
float
));
shape
s
=
{
shape
::
float_type
,
dims
};
instruction_ref
scratch_param
=
p_mod
->
add_parameter
(
"scratch"
,
s
);
for
(
auto
ins
:
iterator_for
(
*
p_mod
))
{
const
instruction
*
p_iter
=
&
(
*
ins
);
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
interval_ptr
interval
=
instr2_live
[
p_iter
];
if
(
interval
->
get_begin
()
==
invalid_offset
)
continue
;
if
(
not
unify_literals
&&
interval
->
is_literal
)
continue
;
std
::
size_t
offset
=
0
;
if
(
interval
->
get_offset
()
!=
invalid_offset
)
{
offset
=
interval
->
get_offset
();
}
else
{
assert
(
interval
->
result
.
bytes
()
==
0
);
}
if
(
is_allocate
(
ins
))
{
p_mod
->
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
offset
}}),
scratch_param
);
}
}
}
MIGRAPHX_DEBUG
(
dump
(
"---After rewrite---"
));
MIGRAPHX_DEBUG
(
dump_module
());
}
void
memory_coloring_impl
::
verify
()
{
if
(
num_of_lives
>
0
)
{
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
const
live_interval
&
interval
=
live_intervals
[
i
];
const
live_range
&
segment
=
interval
.
segment
;
if
(
segment
.
begin
==
invalid_offset
)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue
;
}
if
(
segment
.
offset
==
invalid_offset
)
{
continue
;
}
int
vn
=
segment
.
vn
;
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
if
(
range
->
offset
==
invalid_offset
)
continue
;
if
(
not
is_disjoin
(
*
range
,
segment
))
MIGRAPHX_THROW
(
"range and segment is not disjoined"
);
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void
memory_coloring_impl
::
dump
(
const
std
::
string
&
str
)
{
std
::
cout
<<
str
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_module
()
{
std
::
cout
<<
*
p_mod
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_intervals
()
{
if
(
num_of_lives
>
0
)
{
std
::
cout
<<
"---live intervals ---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
live_interval
&
interval
=
live_intervals
[
i
];
interval
.
dump
();
}
std
::
cout
<<
"---conflict table---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<=
max_value_number
;
++
i
)
{
std
::
cout
<<
" segment:"
<<
i
;
std
::
cout
<<
" =>"
;
const
std
::
set
<
int
>&
table
=
conflict_table
[
i
];
for
(
const
auto
&
iter
:
table
)
{
std
::
cout
<<
(
iter
)
<<
","
;
}
}
std
::
cout
<<
std
::
endl
;
}
}
// map liveness tracking point to instruction enum.
static
int
get_ins_enum
(
int
x
)
{
if
(
x
>
0
)
{
return
(
x
/
2
)
-
1
;
}
else
return
invalid_offset
;
}
void
live_range
::
dump
()
{
std
::
cout
<<
" segment:"
<<
vn
;
std
::
cout
<<
" ["
<<
get_ins_enum
(
begin
)
<<
", "
<<
get_ins_enum
(
end
)
<<
"]"
;
if
(
offset
!=
invalid_offset
)
{
std
::
cout
<<
" mem:"
;
std
::
cout
<<
" ["
<<
offset
<<
","
<<
offset
+
size
-
1
<<
"]"
;
}
std
::
cout
<<
std
::
endl
;
}
void
live_interval
::
dump
()
{
std
::
cout
<<
"id:"
<<
id
;
segment
.
dump
();
std
::
cout
<<
" uses:"
;
for
(
const
auto
&
iter
:
use_points
)
{
std
::
cout
<<
" "
<<
get_ins_enum
(
iter
)
<<
","
;
}
std
::
cout
<<
" def:"
;
std
::
cout
<<
" "
<<
get_ins_enum
(
def_point
);
if
(
is_literal
)
std
::
cout
<<
" literal"
;
std
::
cout
<<
" "
<<
result
;
std
::
cout
<<
std
::
endl
;
}
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/opt/memory_coloring_impl.hpp
deleted
100644 → 0
View file @
870a396b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
static
const
std
::
size_t
invalid_offset
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
struct
live_range
{
std
::
size_t
begin
;
// begin point in the instruction stream.
std
::
size_t
end
;
// end point in the instruction stream.
std
::
size_t
offset
;
// offset to base pointer of allocated memory trunk.
std
::
size_t
vn
;
// value number that identifies this live_range.
std
::
size_t
size
;
// size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
();
#endif
};
struct
live_interval
{
live_interval
()
:
segment
({
invalid_offset
,
invalid_offset
,
invalid_offset
,
invalid_offset
,
0
})
{
}
void
add_use
(
std
::
size_t
use
)
{
use_points
.
push_front
(
use
);
}
std
::
size_t
get_begin
()
const
{
return
segment
.
begin
;
}
std
::
size_t
get_end
()
const
{
return
segment
.
end
;
}
long
long
get_offset
()
const
{
return
segment
.
offset
;
}
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
();
#endif
live_range
segment
;
std
::
size_t
id
=
invalid_offset
;
std
::
list
<
std
::
size_t
>
use_points
{};
std
::
size_t
def_point
=
invalid_offset
;
shape
result
{};
bool
is_literal
=
false
;
bool
is_live_on_entry
=
false
;
};
using
interval_ptr
=
live_interval
*
;
struct
memory_coloring_impl
{
memory_coloring_impl
(
module
*
p
,
std
::
string
alloc_op
,
bool
p_verify
)
:
p_mod
(
p
),
allocation_op
(
std
::
move
(
alloc_op
)),
enable_verify
(
p_verify
)
{
}
bool
allocate
(
interval_ptr
);
void
add_conflicts
(
const
std
::
set
<
int
>&
live_set
,
int
val
)
{
for
(
const
auto
&
iter
:
live_set
)
{
conflict_table
[
iter
].
insert
(
val
);
conflict_table
[
val
].
insert
(
iter
);
}
}
void
build
();
void
run
();
void
rewrite
();
private:
static
bool
is_param
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@param"
;
}
static
bool
is_output_param
(
const
instruction_ref
ins
)
{
if
(
not
is_param
(
ins
))
return
false
;
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
return
contains
(
param_name
,
"#output_"
);
}
bool
is_allocate
(
const
instruction_ref
ins
)
const
{
return
ins
->
name
()
==
allocation_op
;
}
static
bool
is_outline
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@outline"
;
}
static
bool
is_literal
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@literal"
;
}
static
bool
is_check_context
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"check_context"
;
}
static
bool
is_disjoin
(
const
live_range
&
range1
,
const
live_range
&
range2
)
{
if
((
range1
.
size
==
0
)
or
(
range2
.
size
==
0
))
return
false
;
auto
end1
=
range1
.
offset
+
range1
.
size
-
1
;
auto
end2
=
range2
.
offset
+
range2
.
size
-
1
;
return
((
end1
<
range2
.
offset
)
or
(
end2
<
range1
.
offset
));
}
void
verify
();
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
(
const
std
::
string
&
);
void
dump_module
();
void
dump_intervals
();
#endif
struct
ordering
{
bool
operator
()(
const
interval_ptr
&
i1
,
const
interval_ptr
&
i2
)
const
{
auto
len1
=
i1
->
get_end
()
-
i1
->
get_begin
();
auto
len2
=
i2
->
get_end
()
-
i2
->
get_begin
();
if
(
len1
!=
len2
)
{
return
(
len1
<
len2
);
}
else
if
(
i1
->
result
.
bytes
()
!=
i2
->
result
.
bytes
())
{
return
(
i1
->
result
.
bytes
()
<
i2
->
result
.
bytes
());
}
else
{
return
i1
->
id
>
i2
->
id
;
}
}
bool
operator
()(
const
live_range
*
i1
,
const
live_range
*
i2
)
const
{
return
(
i1
->
offset
>
i2
->
offset
);
}
};
module
*
p_mod
;
std
::
unordered_map
<
const
instruction
*
,
interval_ptr
>
instr2_live
;
// universe of live intervals.
std
::
vector
<
live_interval
>
live_intervals
=
{};
// Map live range value number to live range.
std
::
unordered_map
<
int
,
live_range
*>
live_ranges
=
{};
// Map live range value number to a set of conflicting live ranges' value numbers.
std
::
unordered_map
<
int
,
std
::
set
<
int
>>
conflict_table
=
{};
// Priority queue for coloring.
std
::
priority_queue
<
interval_ptr
,
std
::
vector
<
interval_ptr
>
,
ordering
>
alloc_queue
{};
int
num_of_lives
=
0
;
int
max_value_number
=
-
1
;
std
::
size_t
required_bytes
=
0
;
// The earliest program point where an live interval ends.
int
earliest_end_point
=
-
1
;
// The latest program point where an live interval ends.
int
latest_end_point
=
-
1
;
// Whether to unify literals into coloring.
bool
unify_literals
=
false
;
std
::
string
allocation_op
{};
bool
enable_verify
;
ins_dep_map
mod_implicit_deps
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/opt
/memory_coloring
.cpp
→
src/opt
imize_module
.cpp
View file @
30c49503
...
...
@@ -21,18 +21,27 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include "memory_coloring_impl.hpp"
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring
::
apply
(
module
&
m
)
const
void
optimize_module
::
apply
(
module_pass_manager
&
mp
m
)
const
{
if
(
not
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{})
)
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
memory_coloring_impl
opt
(
&
m
,
allocation_op
,
verify
);
opt
.
run
();
mpm
.
run_pass
(
simplify_reshapes
{});
mpm
.
run_pass
(
simplify_algebra
{});
mpm
.
run_pass
(
eliminate_common_subexpression
{});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
propagate_constant
{});
mpm
.
run_pass
(
dead_code_elimination
{});
}
}
...
...
src/pass_manager.cpp
View file @
30c49503
...
...
@@ -39,6 +39,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PASSES
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TIME_PASSES
);
void
validate_pass
(
module
&
mod
,
const
pass
&
p
,
tracer
trace
)
{
...
...
@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
virtual
void
run_pass
(
const
pass
&
p
)
override
{
assert
(
mod
);
timer
ts
{};
using
seconds
=
std
::
chrono
::
duration
<
double
>
;
trace
(
"Module: "
,
mod
->
name
(),
", Pass: "
,
p
.
name
());
const
double
t1
=
ts
.
record
<
seconds
>
();
assert
(
mod
->
validate
()
==
mod
->
end
());
p
.
apply
(
*
this
);
if
(
enabled
(
MIGRAPHX_TIME_PASSES
{}))
{
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
auto
ms
=
time
<
milliseconds
>
([
&
]
{
p
.
apply
(
*
this
);
});
std
::
cout
<<
p
.
name
()
<<
": "
<<
ms
<<
"ms
\n
"
;
}
else
{
p
.
apply
(
*
this
);
}
trace
(
*
mod
);
validate_pass
(
*
mod
,
p
,
*
t
);
const
double
t2
=
ts
.
record
<
seconds
>
();
trace
(
"Pass: "
,
p
.
name
(),
" completed in (s): "
,
(
t2
-
t1
));
}
};
...
...
src/program.cpp
View file @
30c49503
...
...
@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert
(
not
this
->
is_compiled
());
this
->
impl
->
target_name
=
t
.
name
();
this
->
impl
->
ctx
=
t
.
get_context
();
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
(
*
this
);
options
.
trace
();
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
ctx
,
options
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
auto
mods
=
this
->
get_modules
();
// Validate and finalize
for
(
const
auto
&
mod
:
reverse
(
mods
))
{
...
...
@@ -336,7 +334,8 @@ std::vector<argument> generic_eval(const module* mod,
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
{
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
"} for parameter: "
+
param_name
);
"} for parameter: "
+
param_name
+
" should be: "
+
to_string
(
ins
->
get_shape
()));
}
return
param
;
}));
...
...
@@ -380,7 +379,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
if
(
not
ins
->
get_shape
().
dynamic
())
if
(
not
ins
->
get_shape
().
any_of_
dynamic
())
{
assert
(
results
.
at
(
ins
).
get_shape
()
==
ins
->
get_shape
());
}
...
...
src/propagate_constant.cpp
View file @
30c49503
...
...
@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
return
false
;
}
bool
is_const
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
bool
is_const
_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
{
...
...
@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal
for
(
auto
i
:
iterator_for
(
m
))
{
if
(
is_const
(
i
)
and
i
!=
last
)
const
bool
is_const
=
is_const_ins
(
i
);
if
(
is_const
and
i
!=
last
)
continue
;
std
::
copy_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
[
&
](
const
instruction_ref
ins
)
{
return
is_const
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
if
(
i
==
last
and
is_const
)
{
const_instrs
.
insert
(
i
);
}
else
{
std
::
copy_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
[
&
](
const
instruction_ref
ins
)
{
return
is_const_ins
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
}
}
// Compute literals in parallel
...
...
src/py/migraphx_py.cpp
View file @
30c49503
...
...
@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"is_compiled"
,
&
migraphx
::
program
::
is_compiled
)
.
def
(
"compile"
,
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
)
{
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
,
bool
exhaustive_tune
)
{
migraphx
::
compile_options
options
;
options
.
offload_copy
=
offload_copy
;
options
.
fast_math
=
fast_math
;
options
.
offload_copy
=
offload_copy
;
options
.
fast_math
=
fast_math
;
options
.
exhaustive_tune
=
exhaustive_tune
;
p
.
compile
(
t
,
options
);
},
py
::
arg
(
"t"
),
py
::
arg
(
"offload_copy"
)
=
true
,
py
::
arg
(
"fast_math"
)
=
true
)
py
::
arg
(
"offload_copy"
)
=
true
,
py
::
arg
(
"fast_math"
)
=
true
,
py
::
arg
(
"exhaustive_tune"
)
=
false
)
.
def
(
"get_main_module"
,
[](
const
migraphx
::
program
&
p
)
{
return
p
.
get_main_module
();
})
.
def
(
"create_module"
,
...
...
src/register_op.cpp
View file @
30c49503
...
...
@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
static
std
::
unordered_map
<
std
::
string
,
operation
>
m
;
// NOLINT
return
m
;
}
void
register_op_init
()
{
(
void
)
op_map
();
}
void
register_op
(
const
operation
&
op
)
{
op_map
()[
op
.
name
()]
=
op
;
}
void
unregister_op
(
const
std
::
string
&
op_name
)
{
assert
(
op_map
().
count
(
op_name
));
op_map
().
erase
(
op_name
);
}
operation
load_op
(
const
std
::
string
&
name
)
{
return
at
(
op_map
(),
name
,
"Operator not found: "
+
name
);
...
...
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment