Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3f322644
Commit
3f322644
authored
Mar 23, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into ck-gsg
parents
53aee707
09aaa63e
Changes
150
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
459 additions
and
712 deletions
+459
-712
src/include/migraphx/op/allocate.hpp
src/include/migraphx/op/allocate.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+64
-25
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+33
-31
src/include/migraphx/op/normalize_attribute.hpp
src/include/migraphx/op/normalize_attribute.hpp
+21
-9
src/include/migraphx/op/reverse.hpp
src/include/migraphx/op/reverse.hpp
+2
-0
src/include/migraphx/op/select_module.hpp
src/include/migraphx/op/select_module.hpp
+146
-0
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+69
-28
src/include/migraphx/op/where.hpp
src/include/migraphx/op/where.hpp
+13
-5
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+2
-0
src/include/migraphx/register_op.hpp
src/include/migraphx/register_op.hpp
+22
-1
src/include/migraphx/register_target.hpp
src/include/migraphx/register_target.hpp
+15
-1
src/include/migraphx/replace_allocate.hpp
src/include/migraphx/replace_allocate.hpp
+3
-0
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+6
-12
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+3
-0
src/module.cpp
src/module.cpp
+1
-0
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+17
-10
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+7
-2
src/onnx/parse_where.cpp
src/onnx/parse_where.cpp
+34
-18
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+0
-376
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-193
No files found.
src/include/migraphx/op/allocate.hpp
View file @
3f322644
...
...
@@ -44,7 +44,7 @@ struct allocate
std
::
string
name
()
const
{
return
"allocate"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
migraphx
::
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
migraphx
::
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
0
);
return
s
;
}
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
)
const
...
...
src/include/migraphx/op/concat.hpp
View file @
3f322644
...
...
@@ -26,6 +26,7 @@
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
...
...
@@ -73,49 +74,87 @@ struct concat
}
return
offsets
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
// inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
// be at least 1.
check_shapes
{
inputs
,
*
this
,
true
}.
same_ndims
().
same_type
();
if
(
std
::
none_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
{
MIGRAPHX_THROW
(
"CONCAT: Number of input tensors should exceed 0"
);
// Static input shapes
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
type
=
inputs
.
front
().
type
();
for
(
std
::
size_t
ll
=
0
;
ll
<
first_shape_lens
.
size
();
ll
++
)
{
if
(
ll
!=
axis
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
ll
]
==
first_shape_lens
[
ll
];
}))
{
MIGRAPHX_THROW
(
"CONCAT: all input dimensions should match along axis "
+
std
::
to_string
(
ll
));
}
}
}
std
::
size_t
new_dim_axis
=
0
;
for
(
const
auto
&
input
:
inputs
)
{
const
auto
&
lens
=
input
.
lens
();
new_dim_axis
+=
lens
[
axis
];
}
std
::
vector
<
std
::
size_t
>
new_lens
=
first_shape_lens
;
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
}
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
type
=
inputs
.
front
().
type
();
for
(
std
::
size_t
l
=
0
;
l
<
first_shape_lens
.
size
();
l
++
)
else
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
{
if
(
l
!=
axis
)
// Dynamic input shapes
for
(
std
::
size_t
index
=
0
;
index
<
inputs
[
0
].
ndim
();
index
++
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
if
(
index
!=
axis
)
{
MIGRAPHX_THROW
(
"CONCAT: Non-axis dimensions should match"
);
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dyn_dims
()[
index
]
==
inputs
[
0
].
dyn_dims
()[
index
];
}))
MIGRAPHX_THROW
(
"CONCAT: all input dimensions should match in axis "
+
std
::
to_string
(
index
));
}
}
std
::
size_t
new_min
=
0
;
std
::
size_t
new_max
=
0
;
for
(
const
auto
&
input
:
inputs
)
{
auto
ddim
=
input
.
dyn_dims
()[
axis
];
new_min
+=
ddim
.
min
;
new_max
+=
ddim
.
max
;
}
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
,
0
};
return
{
inputs
[
0
].
type
(),
new_dims
};
}
std
::
size_t
new_dim_axis
=
0
;
for
(
const
auto
&
input
:
inputs
)
else
{
const
auto
&
lens
=
input
.
lens
();
new_dim_axis
+=
lens
[
axis
];
MIGRAPHX_THROW
(
"CONCAT: Cannot mix static and dynamic input shapes."
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
copy
(
first_shape_lens
.
begin
(),
first_shape_lens
.
end
(),
std
::
back_inserter
(
new_lens
));
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
out
put_shape
,
args
);
argument
result
{
dyn_out
.
com
put
ed
_shape
};
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
dyn_out
.
com
put
ed
_shape
,
args
);
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
auto
argl
=
args
[
l
];
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
auto
slice_shape
=
shape
{
output_shape
.
type
(),
input
.
get_shape
().
lens
(),
output_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
auto
slice_shape
=
shape
{
dyn_out
.
computed_shape
.
type
(),
input
.
get_shape
().
lens
(),
dyn_out
.
computed_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
std
::
copy
(
input
.
begin
(),
input
.
end
(),
slice
.
begin
());
});
}
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
3f322644
...
...
@@ -143,16 +143,22 @@ struct nonmaxsuppression
void
sort
()
{
std
::
sort
(
x
.
begin
(),
x
.
end
());
std
::
sort
(
y
.
begin
(),
y
.
end
());
if
(
x
[
0
]
>
x
[
1
])
{
std
::
swap
(
x
[
0
],
x
[
1
]);
}
if
(
y
[
0
]
>
y
[
1
])
{
std
::
swap
(
y
[
0
],
y
[
1
]);
}
}
std
::
array
<
double
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
double
area
()
const
{
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
())
);
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
())
);
assert
(
x
[
0
]
<=
x
[
1
]
);
assert
(
y
[
0
]
<=
y
[
1
]
);
return
(
x
[
1
]
-
x
[
0
])
*
(
y
[
1
]
-
y
[
0
]);
}
};
...
...
@@ -190,14 +196,10 @@ struct nonmaxsuppression
{
intersection
[
i
][
0
]
=
std
::
max
(
b1
[
i
][
0
],
b2
[
i
][
0
]);
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
}
std
::
vector
<
std
::
array
<
double
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
}))
{
return
false
;
if
(
intersection
[
i
][
0
]
>
intersection
[
i
][
1
])
{
return
false
;
}
}
const
double
area1
=
b1
.
area
();
...
...
@@ -265,31 +267,31 @@ struct nonmaxsuppression
auto
batch_boxes_start
=
boxes
.
begin
()
+
batch_idx
*
num_boxes
*
4
;
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
//
Check with existing selected boxes for this class, remove box if it
//
exceeds the IOU (Intersection Over Union) threshold
//
select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
//
threshold with the selected box
const
auto
next_top_score
=
boxes_heap
.
top
();
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
batch_box
(
batch_boxes_start
,
selected_index
.
second
),
iou_threshold
);
});
if
(
not
not_selected
)
boxes_heap
.
pop
();
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
remainder_boxes
;
while
(
not
boxes_heap
.
empty
())
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
auto
iou_candidate_box
=
boxes_heap
.
top
();
if
(
not
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
iou_candidate_box
.
second
),
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
iou_threshold
))
{
remainder_boxes
.
push
(
iou_candidate_box
);
}
boxes_heap
.
pop
();
}
boxes_heap
.
pop
()
;
boxes_heap
=
remainder_boxes
;
}
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
...
...
src/include/migraphx/op/normalize_attribute.hpp
View file @
3f322644
...
...
@@ -31,18 +31,30 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
// different attributes
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) clip_min(default)/not_clip_min
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
// 5) normalize padding
/**
* `normalize_attribute` settings:
* Note that default options are not included as enums.
* 1. `use_input` (default) vs. `use_output`:
* Affects the rank of the attribute.
* `use_input -> lens.size()`, `use_output -> lens.size() + vec.size()`.
* 2. use_rank (default) vs use_len:
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
* Include or exclude the minimum value/index for range checking and clipping.
* 5. `clip_max` vs. `not_clip_max` (default):
* Clip values greater than the maximum or not.
* 6. `include_max` vs. `exclude_max` (default):
* Include or exclude the maximum value/index for range checking and clipping.
* 7. `normalize_padding`:
* To normalize the padding to `2*(pad ndim)` dimensions.
*/
enum
class
normalize_attribute
{
use_len
,
use_output
,
use_len
,
clip_max
,
clip_min
,
include_max
,
...
...
src/include/migraphx/op/reverse.hpp
View file @
3f322644
...
...
@@ -28,6 +28,7 @@
#include <vector>
#include <cmath>
#include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp>
...
...
@@ -60,6 +61,7 @@ struct reverse
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
[
0
].
with_lens
(
inputs
[
0
].
lens
());
}
...
...
src/include/migraphx/op/select_module.hpp
0 → 100644
View file @
3f322644
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
select_module
{
shape
output_dyn_shapes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
output_dyn_shapes
,
"output_dyn_shapes"
));
}
std
::
string
name
()
const
{
return
"select_module"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has_at_least
(
1
);
return
shape
{
output_dyn_shapes
};
}
std
::
vector
<
std
::
string
>
get_input_parameter_names
(
module_ref
mod
)
const
{
auto
param_names
=
mod
->
get_parameter_names
();
std
::
vector
<
std
::
string
>
ret
;
std
::
copy_if
(
param_names
.
cbegin
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
not
contains
(
pn
,
"#output_"
);
});
return
ret
;
}
std
::
vector
<
std
::
string
>
get_output_parameter_names
(
module_ref
mod
)
const
{
auto
param_names
=
mod
->
get_parameter_names
();
std
::
vector
<
std
::
string
>
ret
;
std
::
copy_if
(
param_names
.
cbegin
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
contains
(
pn
,
"#output_"
);
});
return
ret
;
}
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
module_ref
>&
submodule_list
,
const
std
::
function
<
std
::
vector
<
argument
>
(
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
)
>&
run
)
const
{
// Find submodule with input parameter shapes exactly the same as the input instruction
// arguments. Assuming instruction arguments are in the same order as the instruction
// parameters.
auto
module_iter
=
std
::
find_if
(
submodule_list
.
cbegin
(),
submodule_list
.
cend
(),
[
&
](
module_ref
mr
)
{
auto
in_param_names
=
get_input_parameter_names
(
mr
);
auto
param_shapes
=
mr
->
get_parameter_shapes
();
assert
(
in_param_names
.
size
()
<=
args
.
size
());
return
std
::
equal
(
in_param_names
.
cbegin
(),
in_param_names
.
cend
(),
args
.
cbegin
(),
[
&
](
auto
p_name
,
auto
a
)
{
return
a
.
get_shape
()
==
param_shapes
[
p_name
];
});
});
if
(
module_iter
==
submodule_list
.
end
())
{
MIGRAPHX_THROW
(
"SELECT_MODULE: no compatible submodules found for given input shapes"
);
}
auto
*
module_to_run
=
*
module_iter
;
std
::
unordered_map
<
std
::
string
,
argument
>
p_map
;
// add input parameters to parameter_map
auto
in_param_names
=
get_input_parameter_names
(
module_to_run
);
assert
(
in_param_names
.
size
()
<=
args
.
size
());
std
::
transform
(
in_param_names
.
begin
(),
in_param_names
.
end
(),
args
.
begin
(),
std
::
inserter
(
p_map
,
p_map
.
end
()),
[
&
](
auto
&&
name
,
auto
&&
a
)
{
return
std
::
make_pair
(
name
,
a
);
});
// One tuple output parameter in main module to multiple output parameters in submodule
auto
out_param_names
=
get_output_parameter_names
(
module_to_run
);
auto
output_sub_objects
=
args
.
back
().
get_sub_objects
();
assert
(
out_param_names
.
size
()
==
output_sub_objects
.
size
());
std
::
transform
(
out_param_names
.
begin
(),
out_param_names
.
end
(),
output_sub_objects
.
begin
(),
std
::
inserter
(
p_map
,
p_map
.
end
()),
[
&
](
auto
&&
name
,
auto
&&
a
)
{
auto
ps
=
module_to_run
->
get_parameter_shape
(
name
);
if
(
a
.
get_shape
()
!=
ps
)
{
assert
(
ps
.
bytes
()
==
a
.
get_shape
().
bytes
());
return
std
::
make_pair
(
name
,
a
.
reshape
(
ps
));
}
else
{
return
std
::
make_pair
(
name
,
a
);
}
});
auto
results
=
run
(
module_to_run
,
p_map
);
return
argument
{
results
};
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/slice.hpp
View file @
3f322644
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
...
...
@@ -46,6 +47,10 @@ struct slice
return
pack
(
f
(
self
.
axes
,
"axes"
),
f
(
self
.
starts
,
"starts"
),
f
(
self
.
ends
,
"ends"
));
}
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value
attributes
()
const
{
value
normalize
=
value
::
object
{};
...
...
@@ -65,14 +70,6 @@ struct slice
std
::
string
name
()
const
{
return
"slice"
;
}
auto
fix_index
(
const
std
::
vector
<
std
::
size_t
>&
lens
,
std
::
size_t
axis
,
int64_t
index
)
const
{
int64_t
r
=
std
::
min
(
index
,
static_cast
<
int64_t
>
(
lens
[
axis
]));
if
(
r
<
0
)
r
+=
lens
[
axis
];
return
std
::
size_t
(
r
);
}
auto
compute_offset
(
const
shape
&
s
)
const
{
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
...
...
@@ -83,14 +80,14 @@ struct slice
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
offset
+=
fix_index
(
lens
,
axis
,
starts
[
i
]
)
*
strides
[
axis
];
offset
+=
starts
[
i
]
*
strides
[
axis
];
}
}
else
{
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
{
offset
+=
fix_index
(
lens
,
axis
,
starts
[
axis
]
)
*
strides
[
axis
];
offset
+=
starts
[
axis
]
*
strides
[
axis
];
}
}
return
offset
;
...
...
@@ -98,37 +95,81 @@ struct slice
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
const
auto
&
old_lens
=
input_shape
.
lens
();
const
auto
&
old_strides
=
input_shape
.
strides
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
i
)
{
return
(
i
>=
old_lens
.
size
()
and
i
<
0
);
}))
// TODO: When support for dynamic shapes is added to normalize_attributes,
// remove this restriction.
if
(
input_shape
.
dynamic
()
and
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
not
input_shape
.
dyn_dims
()[
axis
].
is_fixed
();
}))
{
MIGRAPHX_THROW
(
"SLICE:
input axis "
+
to_string_range
(
axes
)
+
" out of range
"
);
MIGRAPHX_THROW
(
"SLICE:
slicing is not allowed on non-fixed dynamic input axis
"
);
}
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
// For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std
::
vector
<
std
::
size_t
>
new_mins
;
std
::
vector
<
std
::
size_t
>
new_opts
;
std
::
vector
<
std
::
size_t
>
old_lens
;
std
::
vector
<
std
::
size_t
>
old_strides
;
if
(
input_shape
.
dynamic
())
{
old_lens
=
input_shape
.
max_lens
();
new_mins
=
input_shape
.
min_lens
();
new_opts
=
input_shape
.
opt_lens
();
}
else
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
old_lens
=
input_shape
.
lens
();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides
=
input_shape
.
strides
();
}
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
new_lens
[
axis
]
=
fix_index
(
old_lens
,
axis
,
ends
[
i
])
-
fix_index
(
old_lens
,
axis
,
starts
[
i
]);
auto
axis
=
axes
[
i
];
size_t
sliced_length
=
ends
[
i
]
-
starts
[
i
];
// A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens
[
axis
]
=
std
::
min
(
new_lens
[
axis
],
sliced_length
);
if
(
input_shape
.
dynamic
())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std
::
size_t
sliced_min_length
=
ends
[
i
]
-
starts
[
i
];
// if the slice size is smaller than maxes but larger than mins
new_mins
[
axis
]
=
std
::
min
(
sliced_min_length
,
new_mins
[
axis
]);
auto
sliced_opt_length
=
ends
[
i
]
-
starts
[
i
];
if
(
new_opts
[
axis
]
!=
0
)
new_opts
[
axis
]
=
sliced_opt_length
;
if
(
new_opts
[
axis
]
<
new_mins
[
axis
]
or
new_opts
[
axis
]
>
new_lens
[
axis
])
new_opts
[
axis
]
=
0
;
}
}
if
(
input_shape
.
dynamic
())
{
return
shape
{
t
,
new_mins
,
new_lens
,
new_opts
};
}
else
{
return
shape
{
t
,
new_lens
,
old_strides
};
}
return
shape
{
t
,
new_lens
,
old_strides
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
return
{
std
::
move
(
output_shape
),
[
=
]
{
return
input
.
data
()
+
offset
;
}};
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
dyn_out
.
computed_shape
.
type_size
();
return
{
dyn_out
.
computed_shape
,
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
src/include/migraphx/op/where.hpp
View file @
3f322644
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -42,9 +42,17 @@ struct where
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
same_dims
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
3
).
same_dims
();
auto
s1
=
inputs
.
at
(
1
);
auto
s2
=
inputs
.
at
(
2
);
if
(
s1
.
dynamic
()
or
s2
.
dynamic
())
{
if
(
s1
==
s2
)
return
s1
;
MIGRAPHX_THROW
(
"WHERE: dynamic input shapes must be the same"
);
}
// Compare two static shapes, returning a standard shape
if
(
s1
==
s2
and
s1
.
packed
())
{
return
s1
;
...
...
@@ -63,12 +71,12 @@ struct where
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
1
],
args
[
2
])([
&
](
auto
output
,
const
auto
x
,
const
auto
y
)
{
args
[
0
].
visit
([
&
](
const
auto
condition
)
{
par_for
(
out
put_shape
.
elements
(),
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
condition
[
i
]
?
x
[
i
]
:
y
[
i
];
});
});
});
...
...
src/include/migraphx/operation.hpp
View file @
3f322644
...
...
@@ -140,6 +140,8 @@ template <class T>
auto
compute_shape_op
(
rank
<
2
>
,
const
T
&
x
,
const
std
::
vector
<
shape
>&
inputs
)
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
{
if
(
inputs
.
empty
())
MIGRAPHX_THROW
(
"At least one input is required for "
+
x
.
name
());
dependent_type
<
operation
,
T
>
y
=
x
;
normalize_attributes
(
y
,
inputs
[
0
].
max_lens
());
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
...
...
src/include/migraphx/register_op.hpp
View file @
3f322644
...
...
@@ -33,15 +33,36 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void
unregister_op
(
const
std
::
string
&
op_name
);
namespace
detail
{
struct
op_handler
{
operation
op
;
std
::
string
name
;
op_handler
(
const
operation
&
op_r
)
:
op
(
op_r
),
name
(
op
.
name
()){};
~
op_handler
()
{
unregister_op
(
name
);
}
};
}
// namespace detail
void
register_op_init
();
void
register_op
(
const
operation
&
op
);
operation
load_op
(
const
std
::
string
&
name
);
bool
has_op
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_operators
();
template
<
class
T
>
void
register_op
()
{
register_op
(
T
{});
register_op_init
();
// instantiate static op_map;
static
auto
op_h
=
detail
::
op_handler
(
T
{});
register_op
(
op_h
.
op
);
}
struct
register_op_action
...
...
src/include/migraphx/register_target.hpp
View file @
3f322644
...
...
@@ -33,14 +33,28 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
register_target_init
();
void
register_target
(
const
target
&
t
);
void
unregister_target
(
const
std
::
string
&
name
);
target
make_target
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_targets
();
namespace
detail
{
struct
target_handler
{
target
t
;
std
::
string
target_name
;
target_handler
(
const
target
&
t_r
)
:
t
(
t_r
),
target_name
(
t
.
name
())
{}
~
target_handler
()
{
unregister_target
(
target_name
);
}
};
}
// namespace detail
template
<
class
T
>
void
register_target
()
{
register_target
(
T
{});
register_target_init
();
static
auto
t_h
=
detail
::
target_handler
(
T
{});
register_target
(
t_h
.
t
);
}
struct
register_target_action
...
...
src/include/migraphx/replace_allocate.hpp
View file @
3f322644
...
...
@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
/**
* Replace `allocate` instructions with target allocations or output parameters.
*/
struct
replace_allocate
{
allocation_model
model
;
...
...
src/include/migraphx/serialize.hpp
View file @
3f322644
...
...
@@ -93,7 +93,7 @@ auto to_value_impl(rank<4>, const optional<T>& x)
{
value
result
{};
if
(
x
.
has_value
())
to_value
(
*
x
);
return
to_value
(
*
x
);
return
result
;
}
...
...
@@ -212,28 +212,22 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x)
x
=
from_value
<
T
>
(
v
);
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{}
or
std
::
is_enum
<
T
>
{}
)
>
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
{
x
=
v
.
to
<
T
>
();
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
T
&
x
)
{
x
=
v
.
to
<
T
>
();
}
inline
void
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
inline
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
>
auto
from_value_impl
(
rank
<
10
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
auto
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
{
x
.
from_value
(
v
);
}
template
<
class
T
>
auto
from_value_impl
(
rank
<
1
1
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
auto
from_value_impl
(
rank
<
1
0
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
{
migraphx_from_value
(
v
,
x
);
}
...
...
@@ -249,7 +243,7 @@ value to_value(const T& x)
template
<
class
T
>
void
from_value
(
const
value
&
v
,
T
&
x
)
{
detail
::
from_value_impl
(
rank
<
1
1
>
{},
v
,
x
);
detail
::
from_value_impl
(
rank
<
1
0
>
{},
v
,
x
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/shape.hpp
View file @
3f322644
...
...
@@ -243,6 +243,9 @@ struct shape
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
/// Return true if this shape or any of the sub_shapes are dynamic
bool
any_of_dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
...
src/module.cpp
View file @
3f322644
...
...
@@ -166,6 +166,7 @@ void module::assign(const module& m)
auto
s
=
ins
->
get_shape
();
copy_ins
=
impl
->
insert
(
impl
->
instructions
.
end
(),
{
builtin
::
param
{
name
,
order
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
}
else
if
(
ins
->
name
()
==
"@outline"
)
{
...
...
src/normalize_attributes.cpp
View file @
3f322644
...
...
@@ -30,13 +30,16 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// different attributes
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) clip_min(default)/not_clip_min
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
/**
* Parameters:
* vec: the vector attribute to normalize
* axes: the operator's axes attribute if it exists, empty otherwise
* val: the normalize_axes key and options. Ex: normalize["axes"] =
* value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
* normalize_attributes(op&, lens)
*
* See normalize_attribute.hpp for explaining the options.
*/
auto
tune_attribute
(
const
std
::
vector
<
int64_t
>&
vec
,
const
std
::
vector
<
int64_t
>&
axes
,
const
value
&
val
,
...
...
@@ -151,6 +154,11 @@ auto tune_pad_attribute(const value& val)
return
result
;
}
/**
* Assumptions:
* Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input.
*/
bool
normalize_attributes
(
operation
&
op
,
const
std
::
vector
<
std
::
size_t
>&
lens
)
{
bool
tuned
=
false
;
...
...
@@ -158,9 +166,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto
val
=
op
.
to_value
();
if
(
attrs
.
contains
(
"normalize_padding"
))
{
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
// for now, assume the dimensions to pad start at dim 2
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
auto
padding_start
=
2
;
if
(
padding_size
==
2
*
(
lens
.
size
()
-
padding_start
))
...
...
src/onnx/parse_slice.cpp
View file @
3f322644
...
...
@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std
::
vector
<
int64_t
>
steps
;
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice
// to decide whether MIGRAPHX can handle this slice
.
if
(
args
.
size
()
==
5
)
{
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
...
...
@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
}
// If axes arg is not given, the default is all of them.
if
(
op
.
axes
.
empty
())
{
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
op
.
axes
=
axes
;
}
...
...
@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
steps
.
size
()))
{
if
(
steps
[
i
]
>=
0
)
...
...
@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
if
(
not
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
...
...
src/onnx/parse_where.cpp
View file @
3f322644
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
// TODO: broadcasting for dynamic shapes is only implemented
// for binary ops at time of writing, not ternary ops.
// When it becomes available, add multibroadcasting steps in the dynamic shape case.
// For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
}))
{
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
}
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
else
if
(
std
::
none_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
})
)
{
args
[
1
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
}
// If shapes are static and any are broadcasted, insert multibroadcast ops
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
2
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
}
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
1
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
2
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
}
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
}
else
MIGRAPHX_THROW
(
"PARSE_WHERE: doesn't support mixed static and dynamic shape inputs"
);
}
};
...
...
src/opt/memory_coloring_impl.cpp
deleted
100644 → 0
View file @
53aee707
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring_impl
::
run
()
{
// calc implicit depdendencies
mod_implicit_deps
=
p_mod
->
calc_implicit_deps
();
MIGRAPHX_DEBUG
(
dump
(
"---Before memory coloring---"
));
MIGRAPHX_DEBUG
(
dump_module
());
build
();
if
(
num_of_lives
!=
0
)
{
MIGRAPHX_DEBUG
(
dump_intervals
());
// Coloring
while
(
not
alloc_queue
.
empty
())
{
interval_ptr
interval
=
alloc_queue
.
top
();
allocate
(
interval
);
alloc_queue
.
pop
();
}
// rewrite happens after all modules are processed
rewrite
();
if
(
enable_verify
)
verify
();
}
}
bool
memory_coloring_impl
::
allocate
(
interval_ptr
interval
)
{
shape
s
=
interval
->
result
;
std
::
size_t
size
=
s
.
bytes
();
if
(
size
==
0
)
return
false
;
std
::
size_t
element_size
=
(
s
.
elements
()
==
0
?
4
:
(
size
/
s
.
elements
()));
live_range
&
segment
=
interval
->
segment
;
int
vn
=
segment
.
vn
;
std
::
priority_queue
<
live_range
*
,
std
::
vector
<
live_range
*>
,
ordering
>
conflict_queue
;
std
::
unordered_map
<
long
long
,
live_range
*>
offset2_live
;
offset2_live
.
clear
();
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
long
long
offset
=
range
->
offset
;
if
(
offset
!=
invalid_offset
)
{
conflict_queue
.
push
(
range
);
if
(
offset2_live
.
find
(
offset
)
==
offset2_live
.
end
())
{
offset2_live
[
offset
]
=
range
;
}
else
{
live_range
*
prev
=
offset2_live
[
offset
];
assert
(
prev
->
offset
==
offset
);
if
(
prev
->
size
<
range
->
size
)
offset2_live
[
offset
]
=
range
;
}
}
}
}
std
::
size_t
offset
=
0
;
while
(
not
conflict_queue
.
empty
())
{
live_range
*
range
=
conflict_queue
.
top
();
std
::
size_t
iter_offset
=
range
->
offset
;
if
(
offset
>
iter_offset
)
{
offset
=
std
::
max
(
offset
,
iter_offset
+
range
->
size
);
}
else
if
(
offset2_live
[
iter_offset
]
==
range
)
{
if
((
iter_offset
>
offset
)
&&
(
iter_offset
-
offset
)
>=
size
)
{
break
;
}
offset
=
iter_offset
+
range
->
size
;
}
// alignment
if
((
offset
%
element_size
)
!=
0
)
offset
+=
(
element_size
-
(
offset
%
element_size
));
conflict_queue
.
pop
();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset
=
(
offset
+
3
)
/
4
*
4
;
segment
.
offset
=
offset
;
MIGRAPHX_DEBUG
(
segment
.
dump
());
required_bytes
=
std
::
max
(
required_bytes
,
offset
+
segment
.
size
);
return
true
;
}
void
memory_coloring_impl
::
build
()
{
std
::
size_t
num_of_instrs
=
p_mod
->
size
();
if
(
num_of_instrs
==
0
)
return
;
auto
cur_points
=
num_of_instrs
*
2
;
instruction_ref
iter
=
p_mod
->
end
();
instruction_ref
begin
=
p_mod
->
begin
();
std
::
vector
<
instruction_ref
>
dead_instrs
;
std
::
set
<
int
>
live_set
;
// Build live intervals.
live_intervals
.
resize
(
num_of_instrs
);
do
{
iter
=
std
::
prev
(
iter
);
const
instruction
*
p_iter
=
&
(
*
iter
);
interval_ptr
def_interval
=
nullptr
;
bool
is_dead
=
false
;
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
def_interval
=
instr2_live
[
p_iter
];
bool
is_lit
=
is_literal
(
iter
);
if
(
is_allocate
(
iter
)
or
is_lit
)
{
live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
get_shape
();
def_interval
->
is_literal
=
is_lit
;
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
if
(
not
is_lit
or
unify_literals
)
alloc_queue
.
push
(
def_interval
);
live_set
.
erase
(
range
.
vn
);
}
}
else
if
(
not
is_param
(
iter
)
&&
not
is_outline
(
iter
)
&&
not
is_check_context
(
iter
))
{
is_dead
=
true
;
}
auto
inputs
=
iter
->
inputs
();
if
(
contains
(
mod_implicit_deps
,
iter
))
{
const
auto
&
impl_deps
=
mod_implicit_deps
.
at
(
iter
);
inputs
.
insert
(
inputs
.
end
(),
impl_deps
.
begin
(),
impl_deps
.
end
());
}
for
(
auto
&&
arg
:
inputs
)
{
if
(
not
p_mod
->
has_instruction
(
arg
))
continue
;
if
(
is_param
(
arg
)
or
is_outline
(
arg
))
{
if
(
is_output_param
(
arg
))
is_dead
=
false
;
if
(
def_interval
!=
nullptr
)
{
def_interval
->
is_live_on_entry
=
true
;
}
continue
;
}
const
instruction
*
p_arg
=
&
(
*
instruction
::
get_output_alias
(
arg
));
if
(
instr2_live
.
find
(
p_arg
)
==
instr2_live
.
end
())
{
// First time see a use, create a live interval.
int
id
=
num_of_lives
++
;
interval_ptr
interval
=
&
(
live_intervals
[
id
]);
interval
->
id
=
id
;
interval
->
segment
.
end
=
cur_points
;
interval
->
segment
.
vn
=
++
max_value_number
;
interval
->
add_use
(
cur_points
);
instr2_live
[
p_arg
]
=
interval
;
add_conflicts
(
live_set
,
max_value_number
);
live_set
.
insert
(
max_value_number
);
live_ranges
[
max_value_number
]
=
&
(
interval
->
segment
);
earliest_end_point
=
cur_points
;
if
(
latest_end_point
==
-
1
)
latest_end_point
=
cur_points
;
}
else
{
interval_ptr
interval
=
instr2_live
[
p_arg
];
interval
->
add_use
(
cur_points
);
assert
(
live_set
.
find
(
interval
->
id
)
!=
live_set
.
end
());
}
}
if
(
is_dead
)
dead_instrs
.
push_back
(
iter
);
cur_points
-=
2
;
}
while
(
iter
!=
begin
);
}
void
memory_coloring_impl
::
rewrite
()
{
std
::
vector
<
std
::
size_t
>
dims
;
dims
.
push_back
((
required_bytes
+
sizeof
(
float
)
-
1
)
/
sizeof
(
float
));
shape
s
=
{
shape
::
float_type
,
dims
};
instruction_ref
scratch_param
=
p_mod
->
add_parameter
(
"scratch"
,
s
);
for
(
auto
ins
:
iterator_for
(
*
p_mod
))
{
const
instruction
*
p_iter
=
&
(
*
ins
);
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
interval_ptr
interval
=
instr2_live
[
p_iter
];
if
(
interval
->
get_begin
()
==
invalid_offset
)
continue
;
if
(
not
unify_literals
&&
interval
->
is_literal
)
continue
;
std
::
size_t
offset
=
0
;
if
(
interval
->
get_offset
()
!=
invalid_offset
)
{
offset
=
interval
->
get_offset
();
}
else
{
assert
(
interval
->
result
.
bytes
()
==
0
);
}
if
(
is_allocate
(
ins
))
{
p_mod
->
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
offset
}}),
scratch_param
);
}
}
}
MIGRAPHX_DEBUG
(
dump
(
"---After rewrite---"
));
MIGRAPHX_DEBUG
(
dump_module
());
}
void
memory_coloring_impl
::
verify
()
{
if
(
num_of_lives
>
0
)
{
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
const
live_interval
&
interval
=
live_intervals
[
i
];
const
live_range
&
segment
=
interval
.
segment
;
if
(
segment
.
begin
==
invalid_offset
)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue
;
}
if
(
segment
.
offset
==
invalid_offset
)
{
continue
;
}
int
vn
=
segment
.
vn
;
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
if
(
range
->
offset
==
invalid_offset
)
continue
;
if
(
not
is_disjoin
(
*
range
,
segment
))
MIGRAPHX_THROW
(
"range and segment is not disjoined"
);
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void
memory_coloring_impl
::
dump
(
const
std
::
string
&
str
)
{
std
::
cout
<<
str
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_module
()
{
std
::
cout
<<
*
p_mod
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_intervals
()
{
if
(
num_of_lives
>
0
)
{
std
::
cout
<<
"---live intervals ---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
live_interval
&
interval
=
live_intervals
[
i
];
interval
.
dump
();
}
std
::
cout
<<
"---conflict table---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<=
max_value_number
;
++
i
)
{
std
::
cout
<<
" segment:"
<<
i
;
std
::
cout
<<
" =>"
;
const
std
::
set
<
int
>&
table
=
conflict_table
[
i
];
for
(
const
auto
&
iter
:
table
)
{
std
::
cout
<<
(
iter
)
<<
","
;
}
}
std
::
cout
<<
std
::
endl
;
}
}
// map liveness tracking point to instruction enum.
static
int
get_ins_enum
(
int
x
)
{
if
(
x
>
0
)
{
return
(
x
/
2
)
-
1
;
}
else
return
invalid_offset
;
}
void
live_range
::
dump
()
{
std
::
cout
<<
" segment:"
<<
vn
;
std
::
cout
<<
" ["
<<
get_ins_enum
(
begin
)
<<
", "
<<
get_ins_enum
(
end
)
<<
"]"
;
if
(
offset
!=
invalid_offset
)
{
std
::
cout
<<
" mem:"
;
std
::
cout
<<
" ["
<<
offset
<<
","
<<
offset
+
size
-
1
<<
"]"
;
}
std
::
cout
<<
std
::
endl
;
}
void
live_interval
::
dump
()
{
std
::
cout
<<
"id:"
<<
id
;
segment
.
dump
();
std
::
cout
<<
" uses:"
;
for
(
const
auto
&
iter
:
use_points
)
{
std
::
cout
<<
" "
<<
get_ins_enum
(
iter
)
<<
","
;
}
std
::
cout
<<
" def:"
;
std
::
cout
<<
" "
<<
get_ins_enum
(
def_point
);
if
(
is_literal
)
std
::
cout
<<
" literal"
;
std
::
cout
<<
" "
<<
result
;
std
::
cout
<<
std
::
endl
;
}
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/opt/memory_coloring_impl.hpp
deleted
100644 → 0
View file @
53aee707
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
static
const
std
::
size_t
invalid_offset
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
struct
live_range
{
std
::
size_t
begin
;
// begin point in the instruction stream.
std
::
size_t
end
;
// end point in the instruction stream.
std
::
size_t
offset
;
// offset to base pointer of allocated memory trunk.
std
::
size_t
vn
;
// value number that identifies this live_range.
std
::
size_t
size
;
// size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
();
#endif
};
struct
live_interval
{
live_interval
()
:
segment
({
invalid_offset
,
invalid_offset
,
invalid_offset
,
invalid_offset
,
0
})
{
}
void
add_use
(
std
::
size_t
use
)
{
use_points
.
push_front
(
use
);
}
std
::
size_t
get_begin
()
const
{
return
segment
.
begin
;
}
std
::
size_t
get_end
()
const
{
return
segment
.
end
;
}
long
long
get_offset
()
const
{
return
segment
.
offset
;
}
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
();
#endif
live_range
segment
;
std
::
size_t
id
=
invalid_offset
;
std
::
list
<
std
::
size_t
>
use_points
{};
std
::
size_t
def_point
=
invalid_offset
;
shape
result
{};
bool
is_literal
=
false
;
bool
is_live_on_entry
=
false
;
};
using
interval_ptr
=
live_interval
*
;
struct
memory_coloring_impl
{
memory_coloring_impl
(
module
*
p
,
std
::
string
alloc_op
,
bool
p_verify
)
:
p_mod
(
p
),
allocation_op
(
std
::
move
(
alloc_op
)),
enable_verify
(
p_verify
)
{
}
bool
allocate
(
interval_ptr
);
void
add_conflicts
(
const
std
::
set
<
int
>&
live_set
,
int
val
)
{
for
(
const
auto
&
iter
:
live_set
)
{
conflict_table
[
iter
].
insert
(
val
);
conflict_table
[
val
].
insert
(
iter
);
}
}
void
build
();
void
run
();
void
rewrite
();
private:
static
bool
is_param
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@param"
;
}
static
bool
is_output_param
(
const
instruction_ref
ins
)
{
if
(
not
is_param
(
ins
))
return
false
;
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
return
contains
(
param_name
,
"#output_"
);
}
bool
is_allocate
(
const
instruction_ref
ins
)
const
{
return
ins
->
name
()
==
allocation_op
;
}
static
bool
is_outline
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@outline"
;
}
static
bool
is_literal
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@literal"
;
}
static
bool
is_check_context
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"check_context"
;
}
static
bool
is_disjoin
(
const
live_range
&
range1
,
const
live_range
&
range2
)
{
if
((
range1
.
size
==
0
)
or
(
range2
.
size
==
0
))
return
false
;
auto
end1
=
range1
.
offset
+
range1
.
size
-
1
;
auto
end2
=
range2
.
offset
+
range2
.
size
-
1
;
return
((
end1
<
range2
.
offset
)
or
(
end2
<
range1
.
offset
));
}
void
verify
();
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
(
const
std
::
string
&
);
void
dump_module
();
void
dump_intervals
();
#endif
struct
ordering
{
bool
operator
()(
const
interval_ptr
&
i1
,
const
interval_ptr
&
i2
)
const
{
auto
len1
=
i1
->
get_end
()
-
i1
->
get_begin
();
auto
len2
=
i2
->
get_end
()
-
i2
->
get_begin
();
if
(
len1
!=
len2
)
{
return
(
len1
<
len2
);
}
else
if
(
i1
->
result
.
bytes
()
!=
i2
->
result
.
bytes
())
{
return
(
i1
->
result
.
bytes
()
<
i2
->
result
.
bytes
());
}
else
{
return
i1
->
id
>
i2
->
id
;
}
}
bool
operator
()(
const
live_range
*
i1
,
const
live_range
*
i2
)
const
{
return
(
i1
->
offset
>
i2
->
offset
);
}
};
module
*
p_mod
;
std
::
unordered_map
<
const
instruction
*
,
interval_ptr
>
instr2_live
;
// universe of live intervals.
std
::
vector
<
live_interval
>
live_intervals
=
{};
// Map live range value number to live range.
std
::
unordered_map
<
int
,
live_range
*>
live_ranges
=
{};
// Map live range value number to a set of conflicting live ranges' value numbers.
std
::
unordered_map
<
int
,
std
::
set
<
int
>>
conflict_table
=
{};
// Priority queue for coloring.
std
::
priority_queue
<
interval_ptr
,
std
::
vector
<
interval_ptr
>
,
ordering
>
alloc_queue
{};
int
num_of_lives
=
0
;
int
max_value_number
=
-
1
;
std
::
size_t
required_bytes
=
0
;
// The earliest program point where an live interval ends.
int
earliest_end_point
=
-
1
;
// The latest program point where an live interval ends.
int
latest_end_point
=
-
1
;
// Whether to unify literals into coloring.
bool
unify_literals
=
false
;
std
::
string
allocation_op
{};
bool
enable_verify
;
ins_dep_map
mod_implicit_deps
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment