Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
30c49503
Commit
30c49503
authored
Mar 23, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
870a396b
09aaa63e
Changes
202
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
681 additions
and
196 deletions
+681
-196
src/include/migraphx/instruction_ref.hpp
src/include/migraphx/instruction_ref.hpp
+2
-2
src/include/migraphx/match/layernorm.hpp
src/include/migraphx/match/layernorm.hpp
+30
-6
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+3
-2
src/include/migraphx/op/allocate.hpp
src/include/migraphx/op/allocate.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+64
-25
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+40
-15
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+88
-17
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+33
-31
src/include/migraphx/op/normalize_attribute.hpp
src/include/migraphx/op/normalize_attribute.hpp
+21
-9
src/include/migraphx/op/reverse.hpp
src/include/migraphx/op/reverse.hpp
+2
-0
src/include/migraphx/op/scatternd_op.hpp
src/include/migraphx/op/scatternd_op.hpp
+66
-21
src/include/migraphx/op/select_module.hpp
src/include/migraphx/op/select_module.hpp
+146
-0
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+69
-28
src/include/migraphx/op/where.hpp
src/include/migraphx/op/where.hpp
+13
-5
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+2
-0
src/include/migraphx/optimize_module.hpp
src/include/migraphx/optimize_module.hpp
+16
-6
src/include/migraphx/register_op.hpp
src/include/migraphx/register_op.hpp
+22
-1
src/include/migraphx/register_target.hpp
src/include/migraphx/register_target.hpp
+15
-1
src/include/migraphx/replace_allocate.hpp
src/include/migraphx/replace_allocate.hpp
+3
-0
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+45
-26
No files found.
src/include/migraphx/instruction_ref.hpp
View file @
30c49503
...
@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
...
@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
namespace
std
{
namespace
std
{
template
<
>
template
<
>
struct
hash
<
migraphx
::
instruction_ref
>
struct
hash
<
migraphx
::
instruction_ref
>
// NOLINT
{
{
using
argument_type
=
migraphx
::
instruction_ref
;
using
argument_type
=
migraphx
::
instruction_ref
;
using
result_type
=
std
::
size_t
;
using
result_type
=
std
::
size_t
;
...
@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
...
@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
};
};
template
<
>
template
<
>
struct
equal_to
<
migraphx
::
instruction_ref
>
struct
equal_to
<
migraphx
::
instruction_ref
>
// NOLINT
{
{
using
argument_type
=
migraphx
::
instruction_ref
;
using
argument_type
=
migraphx
::
instruction_ref
;
using
result_type
=
bool
;
using
result_type
=
bool
;
...
...
src/include/migraphx/match/layernorm.hpp
View file @
30c49503
...
@@ -36,22 +36,46 @@ template <class F>
...
@@ -36,22 +36,46 @@ template <class F>
struct
layernorm_matcher
struct
layernorm_matcher
{
{
F
f
;
F
f
;
auto
last_axis
()
const
{
return
make_basic_pred_matcher
([](
instruction_ref
ins
)
{
auto
v
=
ins
->
get_operator
().
to_value
();
if
(
not
v
.
contains
(
"axes"
))
return
false
;
auto
axes
=
v
[
"axes"
].
to_vector
<
std
::
size_t
>
();
if
(
axes
.
size
()
!=
1
)
return
false
;
return
axes
.
front
()
==
ins
->
inputs
().
front
()
->
get_shape
().
lens
().
size
()
-
1
;
});
}
auto
reduce_mean
()
const
{
return
f
(
"reduce_mean"
)(
last_axis
());
}
auto
x_minus_mean
()
const
auto
x_minus_mean
()
const
{
{
return
f
(
"sub"
)(
arg
(
0
)(
any
().
bind
(
"x"
)),
arg
(
1
)(
skip_broadcasts
(
f
(
"
reduce_mean
"
))));
return
f
(
"sub"
)(
arg
(
0
)(
any
().
bind
(
"x"
)),
arg
(
1
)(
skip_broadcasts
(
reduce_mean
(
))));
}
}
auto
variance
()
const
auto
variance
()
const
{
{
return
f
(
"reduce_mean"
)(
arg
(
0
)(
f
(
"pow"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
has_value
(
2.0
f
)))));
return
reduce_mean
()(
arg
(
0
)(
any_of
(
f
(
"pow"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
has_value
(
2.0
f
))),
f
(
"mul"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
x_minus_mean
())),
f
(
"sqdiff"
)(
either_arg
(
0
,
1
)(
any
().
bind
(
"x"
),
skip_broadcasts
(
reduce_mean
()))))));
}
}
auto
layernorm_onnx
(
)
const
auto
sqrt_add_eps
(
const
std
::
string
&
name
)
const
{
{
return
f
(
"div"
)(
arg
(
0
)(
x_minus_mean
()),
auto
add_eps
=
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
is_constant
().
bind
(
"eps"
)));
return
skip_broadcasts
(
f
(
name
)(
arg
(
0
)(
any_of
(
add_eps
,
variance
()))));
}
arg
(
1
)(
skip_broadcasts
(
f
(
"sqrt"
)(
arg
(
0
)(
auto
layernorm_onnx
()
const
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
is_constant
().
bind
(
"eps"
))))))));
{
auto
div_sqrt
=
f
(
"div"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
sqrt_add_eps
(
"sqrt"
)));
auto
mul_rsqrt
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
x_minus_mean
(),
sqrt_add_eps
(
"rsqrt"
)));
return
any
(
any_of
(
div_sqrt
,
mul_rsqrt
));
}
}
auto
matcher
()
const
{
return
layernorm_onnx
();
}
auto
matcher
()
const
{
return
layernorm_onnx
();
}
...
...
src/include/migraphx/memory_coloring.hpp
View file @
30c49503
...
@@ -33,13 +33,14 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -33,13 +33,14 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
struct
module
;
/**
/**
* Remove memory allocations. It uses graph coloring to find memory allocations that can be reused.
* Remove multiple memory allocations using graph coloring to find memory allocations that can be
* reused.
*/
*/
struct
memory_coloring
struct
memory_coloring
{
{
std
::
string
allocation_op
{};
std
::
string
allocation_op
{};
bool
verify
=
false
;
bool
verify
=
false
;
std
::
string
name
()
const
{
return
"memory
coloring"
;
}
std
::
string
name
()
const
{
return
"memory
_
coloring"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
...
...
src/include/migraphx/op/allocate.hpp
View file @
30c49503
...
@@ -44,7 +44,7 @@ struct allocate
...
@@ -44,7 +44,7 @@ struct allocate
std
::
string
name
()
const
{
return
"allocate"
;
}
std
::
string
name
()
const
{
return
"allocate"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
migraphx
::
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
migraphx
::
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
0
);
return
s
;
return
s
;
}
}
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
)
const
...
...
src/include/migraphx/op/concat.hpp
View file @
30c49503
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <array>
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
...
@@ -73,24 +74,28 @@ struct concat
...
@@ -73,24 +74,28 @@ struct concat
}
}
return
offsets
;
return
offsets
;
}
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
empty
())
// inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
{
// be at least 1.
MIGRAPHX_THROW
(
"CONCAT: Number of input tensors should exceed 0"
);
check_shapes
{
inputs
,
*
this
,
true
}.
same_ndims
().
same_type
();
}
if
(
std
::
none_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
{
// Static input shapes
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
type
=
inputs
.
front
().
type
();
const
auto
&
type
=
inputs
.
front
().
type
();
for
(
std
::
size_t
l
=
0
;
l
<
first_shape_lens
.
size
();
l
++
)
for
(
std
::
size_t
l
l
=
0
;
l
l
<
first_shape_lens
.
size
();
l
l
++
)
{
{
if
(
l
!=
axis
)
if
(
l
l
!=
axis
)
{
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
return
s
.
lens
()[
l
l
]
==
first_shape_lens
[
l
l
];
}))
}))
{
{
MIGRAPHX_THROW
(
"CONCAT: Non-axis dimensions should match"
);
MIGRAPHX_THROW
(
"CONCAT: all input dimensions should match along axis "
+
std
::
to_string
(
ll
));
}
}
}
}
}
}
...
@@ -100,21 +105,55 @@ struct concat
...
@@ -100,21 +105,55 @@ struct concat
const
auto
&
lens
=
input
.
lens
();
const
auto
&
lens
=
input
.
lens
();
new_dim_axis
+=
lens
[
axis
];
new_dim_axis
+=
lens
[
axis
];
}
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
vector
<
std
::
size_t
>
new_lens
=
first_shape_lens
;
std
::
copy
(
first_shape_lens
.
begin
(),
first_shape_lens
.
end
(),
std
::
back_inserter
(
new_lens
));
new_lens
[
axis
]
=
new_dim_axis
;
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
else
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
{
// Dynamic input shapes
for
(
std
::
size_t
index
=
0
;
index
<
inputs
[
0
].
ndim
();
index
++
)
{
if
(
index
!=
axis
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dyn_dims
()[
index
]
==
inputs
[
0
].
dyn_dims
()[
index
];
}))
MIGRAPHX_THROW
(
"CONCAT: all input dimensions should match in axis "
+
std
::
to_string
(
index
));
}
}
std
::
size_t
new_min
=
0
;
std
::
size_t
new_max
=
0
;
for
(
const
auto
&
input
:
inputs
)
{
auto
ddim
=
input
.
dyn_dims
()[
axis
];
new_min
+=
ddim
.
min
;
new_max
+=
ddim
.
max
;
}
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
,
0
};
return
{
inputs
[
0
].
type
(),
new_dims
};
}
else
{
MIGRAPHX_THROW
(
"CONCAT: Cannot mix static and dynamic input shapes."
);
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
out
put_shape
,
args
);
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
dyn_out
.
com
put
ed
_shape
,
args
);
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
{
auto
argl
=
args
[
l
];
auto
argl
=
args
[
l
];
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
auto
slice_shape
=
auto
slice_shape
=
shape
{
dyn_out
.
computed_shape
.
type
(),
shape
{
output_shape
.
type
(),
input
.
get_shape
().
lens
(),
output_shape
.
strides
()};
input
.
get_shape
().
lens
(),
dyn_out
.
computed_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
std
::
copy
(
input
.
begin
(),
input
.
end
(),
slice
.
begin
());
std
::
copy
(
input
.
begin
(),
input
.
end
(),
slice
.
begin
());
});
});
...
...
src/include/migraphx/op/gather.hpp
View file @
30c49503
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <array>
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
...
@@ -61,13 +62,36 @@ struct gather
...
@@ -61,13 +62,36 @@ struct gather
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
shape
data
=
inputs
[
0
];
auto
type
=
inputs
[
0
].
type
();
shape
indices
=
inputs
[
1
];
auto
type
=
data
.
type
();
// If index_dims is dynamic, convert the data to dynamic too.
if
(
indices
.
dynamic
())
{
data
=
data
.
to_dynamic
();
}
if
(
data
.
dynamic
())
{
auto
dims
=
data
.
dyn_dims
();
dims
.
erase
(
dims
.
begin
()
+
axis
);
if
(
not
indices
.
scalar
())
{
auto
index_dims
=
indices
.
to_dynamic
().
dyn_dims
();
dims
.
insert
(
dims
.
begin
()
+
axis
,
index_dims
.
begin
(),
index_dims
.
end
());
}
return
{
type
,
dims
};
}
else
{
// Both data and indices are static. indices may be scalar
auto
lens
=
data
.
lens
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
lens
.
erase
(
lens
.
begin
()
+
axis
);
if
(
not
inputs
[
1
].
scalar
())
if
(
not
indices
.
scalar
())
{
{
auto
ind_lens
=
in
puts
[
1
]
.
lens
();
auto
ind_lens
=
in
dices
.
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
}
}
...
@@ -79,17 +103,18 @@ struct gather
...
@@ -79,17 +103,18 @@ struct gather
return
{
type
,
lens
};
return
{
type
,
lens
};
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
// negative axis means counting dimensions from back
// negative axis means counting dimensions from back
auto
lens
=
args
[
0
].
get_shape
().
lens
();
auto
lens
=
args
[
0
].
get_shape
().
lens
();
std
::
size_t
axis_dim_size
=
lens
[
axis
];
std
::
size_t
axis_dim_size
=
lens
[
axis
];
// max dimension in axis
// max dimension in axis
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
if
(
out
put_shape
.
scalar
())
if
(
dyn_out
.
com
put
ed
_shape
.
scalar
())
{
{
auto
in_index
=
indices
.
front
();
auto
in_index
=
indices
.
front
();
in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
...
...
src/include/migraphx/op/gathernd.hpp
View file @
30c49503
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
...
@@ -47,33 +48,103 @@ struct gathernd
...
@@ -47,33 +48,103 @@ struct gathernd
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
i_shape
=
inputs
.
back
();
auto
q
=
inputs
.
back
().
lens
().
size
();
auto
data_shape
=
inputs
.
front
();
auto
k
=
inputs
.
back
().
lens
().
back
();
auto
r
=
data_shape
.
ndim
();
auto
q
=
i_shape
.
ndim
();
size_t
k
;
if
(
i_shape
.
dynamic
())
{
// the rank of the output is a function of k, so it must be fixed.
if
(
not
i_shape
.
dyn_dims
().
back
().
is_fixed
())
{
MIGRAPHX_THROW
(
"GATHERND: last dimension of indices tensor must be fixed (min=max)"
);
}
k
=
i_shape
.
dyn_dims
().
back
().
min
;
}
else
k
=
i_shape
.
lens
().
back
();
// Begin input validation checks.
int
output_ndim
=
int
(
q
)
+
r
-
k
-
batch_dims
-
1
;
if
(
k
>
r
-
batch_dims
)
if
(
k
>
r
-
batch_dims
)
{
{
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
" cannot be used to access data of rank "
+
" cannot be used to access data of rank "
+
std
::
to_string
(
r
-
batch_dims
));
std
::
to_string
(
r
-
batch_dims
));
}
}
auto
indices_lens_iter
=
inputs
.
back
().
lens
().
begin
();
auto
output_lens_size
=
q
+
r
-
k
-
batch_dims
-
1
;
if
(
batch_dims
>=
q
or
batch_dims
>=
r
)
std
::
vector
<
std
::
size_t
>
output_lens
(
output_lens_size
);
{
MIGRAPHX_THROW
(
"GATHERND: rank of an input cannot be less than batch_dims="
+
std
::
to_string
(
batch_dims
));
}
if
(
output_ndim
<
0
)
{
MIGRAPHX_THROW
(
"GATHERND: Indices too large for static data input: k="
+
std
::
to_string
(
k
));
}
if
(
migraphx
::
none_of
(
inputs
,
[](
auto
v
)
{
return
v
.
dynamic
();
}))
{
auto
indices_lens_iter
=
i_shape
.
lens
().
begin
();
// A rank 0 output is a scalar
if
(
output_ndim
==
0
)
return
shape
{
data_shape
.
type
(),
{
1
}};
// Part of the output shape comes from indices tensor, part from data tensor
std
::
vector
<
std
::
size_t
>
output_lens
(
output_ndim
);
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
if
(
k
<
r
-
batch_dims
)
// fill the rest of output shape from data tensor
if
(
k
+
batch_dims
<
r
)
{
{
auto
data_lens
=
inputs
.
front
().
lens
();
auto
data_lens
=
data_shape
.
lens
();
std
::
copy
(
std
::
copy
(
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
}
}
shape
output_shape
{
inputs
.
front
()
.
type
(),
output_lens
};
shape
output_shape
{
data_shape
.
type
(),
output_lens
};
return
output_shape
;
return
output_shape
;
}
}
else
{
// If one or both inputs are dynamic shapes, the output is dynamic.
// Make both inputs dynamic to simplify computations.
data_shape
=
data_shape
.
to_dynamic
();
i_shape
=
i_shape
.
to_dynamic
();
// A rank 0 output is a scalar
if
(
output_ndim
==
0
)
return
shape
(
data_shape
.
type
(),
{
shape
::
dynamic_dimension
({
1
,
1
,
0
})});
// Part of the output shape comes from indices tensor, part from data tensor
std
::
vector
<
shape
::
dynamic_dimension
>
output_dims
(
output_ndim
);
std
::
copy
(
i_shape
.
dyn_dims
().
begin
(),
i_shape
.
dyn_dims
().
begin
()
+
q
-
1
,
output_dims
.
begin
());
// fill the rest of output shape from data tensor
if
(
k
+
batch_dims
<
r
)
{
auto
data_dims
=
data_shape
.
dyn_dims
();
std
::
copy
(
data_dims
.
begin
()
+
batch_dims
+
k
,
data_dims
.
begin
()
+
r
,
output_dims
.
begin
()
+
q
-
1
);
}
shape
output_shape
(
data_shape
.
type
(),
output_dims
);
return
output_shape
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape
=
indices
.
get_shape
();
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
30c49503
...
@@ -143,16 +143,22 @@ struct nonmaxsuppression
...
@@ -143,16 +143,22 @@ struct nonmaxsuppression
void
sort
()
void
sort
()
{
{
std
::
sort
(
x
.
begin
(),
x
.
end
());
if
(
x
[
0
]
>
x
[
1
])
std
::
sort
(
y
.
begin
(),
y
.
end
());
{
std
::
swap
(
x
[
0
],
x
[
1
]);
}
if
(
y
[
0
]
>
y
[
1
])
{
std
::
swap
(
y
[
0
],
y
[
1
]);
}
}
}
std
::
array
<
double
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
std
::
array
<
double
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
double
area
()
const
double
area
()
const
{
{
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
())
);
assert
(
x
[
0
]
<=
x
[
1
]
);
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
())
);
assert
(
y
[
0
]
<=
y
[
1
]
);
return
(
x
[
1
]
-
x
[
0
])
*
(
y
[
1
]
-
y
[
0
]);
return
(
x
[
1
]
-
x
[
0
])
*
(
y
[
1
]
-
y
[
0
]);
}
}
};
};
...
@@ -190,15 +196,11 @@ struct nonmaxsuppression
...
@@ -190,15 +196,11 @@ struct nonmaxsuppression
{
{
intersection
[
i
][
0
]
=
std
::
max
(
b1
[
i
][
0
],
b2
[
i
][
0
]);
intersection
[
i
][
0
]
=
std
::
max
(
b1
[
i
][
0
],
b2
[
i
][
0
]);
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
}
if
(
intersection
[
i
][
0
]
>
intersection
[
i
][
1
])
std
::
vector
<
std
::
array
<
double
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
}))
{
{
return
false
;
return
false
;
}
}
}
const
double
area1
=
b1
.
area
();
const
double
area1
=
b1
.
area
();
const
double
area2
=
b2
.
area
();
const
double
area2
=
b2
.
area
();
...
@@ -265,32 +267,32 @@ struct nonmaxsuppression
...
@@ -265,32 +267,32 @@ struct nonmaxsuppression
auto
batch_boxes_start
=
boxes
.
begin
()
+
batch_idx
*
num_boxes
*
4
;
auto
batch_boxes_start
=
boxes
.
begin
()
+
batch_idx
*
num_boxes
*
4
;
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
not
boxes_heap
.
empty
()
&&
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
{
//
Check with existing selected boxes for this class, remove box if it
//
select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
//
exceeds the IOU (Intersection Over Union) threshold
//
threshold with the selected box
const
auto
next_top_score
=
boxes_heap
.
top
();
const
auto
next_top_score
=
boxes_heap
.
top
();
bool
not_selected
=
boxes_heap
.
pop
();
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
batch_box
(
batch_boxes_start
,
selected_index
.
second
),
iou_threshold
);
});
if
(
not
not_selected
)
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
selected_indices
.
push_back
(
next_top_score
.
second
);
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
remainder_boxes
;
while
(
not
boxes_heap
.
empty
())
{
auto
iou_candidate_box
=
boxes_heap
.
top
();
if
(
not
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
iou_candidate_box
.
second
),
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
iou_threshold
))
{
remainder_boxes
.
push
(
iou_candidate_box
);
}
}
boxes_heap
.
pop
();
boxes_heap
.
pop
();
}
}
boxes_heap
=
remainder_boxes
;
}
});
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
return
selected_indices
.
size
()
/
3
;
return
selected_indices
.
size
()
/
3
;
...
...
src/include/migraphx/op/normalize_attribute.hpp
View file @
30c49503
...
@@ -31,18 +31,30 @@ namespace migraphx {
...
@@ -31,18 +31,30 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
// different attributes
/**
// 1) use_input(default)/use_output
* `normalize_attribute` settings:
// 2) use_rank(default)/use_len
* Note that default options are not included as enums.
// 3) clip_min(default)/not_clip_min
* 1. `use_input` (default) vs. `use_output`:
// 3.1) include_min(default)/exclude_min
* Affects the rank of the attribute.
// 4) clip_max(default)/not_clip_max
* `use_input -> lens.size()`, `use_output -> lens.size() + vec.size()`.
// 4.1) exclude_max(default)/include_max
* 2. use_rank (default) vs use_len:
// 5) normalize padding
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
* Include or exclude the minimum value/index for range checking and clipping.
* 5. `clip_max` vs. `not_clip_max` (default):
* Clip values greater than the maximum or not.
* 6. `include_max` vs. `exclude_max` (default):
* Include or exclude the maximum value/index for range checking and clipping.
* 7. `normalize_padding`:
* To normalize the padding to `2*(pad ndim)` dimensions.
*/
enum
class
normalize_attribute
enum
class
normalize_attribute
{
{
use_len
,
use_output
,
use_output
,
use_len
,
clip_max
,
clip_max
,
clip_min
,
clip_min
,
include_max
,
include_max
,
...
...
src/include/migraphx/op/reverse.hpp
View file @
30c49503
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <vector>
#include <vector>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
...
@@ -60,6 +61,7 @@ struct reverse
...
@@ -60,6 +61,7 @@ struct reverse
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
[
0
].
with_lens
(
inputs
[
0
].
lens
());
return
inputs
[
0
].
with_lens
(
inputs
[
0
].
lens
());
}
}
...
...
src/include/migraphx/op/scatternd_op.hpp
View file @
30c49503
...
@@ -28,44 +28,89 @@
...
@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* @brief
* N-dimensional Scatter operations. This struct is parent class to ops which differ in what formula
* is used to reduce (combine old and new values of) the scattered value. It was originally based
* on Onnx ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to Numpy
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
* operations.
*/
template
<
class
Derived
>
template
<
class
Derived
>
struct
scatternd_op
:
op_name
<
Derived
>
struct
scatternd_op
:
op_name
<
Derived
>
{
{
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
3
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
data_shape
=
inputs
.
front
();
auto
q
=
inputs
.
at
(
1
).
lens
().
size
();
auto
index_shape
=
inputs
.
at
(
1
);
auto
k
=
inputs
.
at
(
1
).
lens
().
back
();
auto
upd_shape
=
inputs
.
back
();
auto
ind_lens
=
inputs
.
at
(
1
).
lens
();
auto
upd_lens
=
inputs
.
back
().
lens
();
auto
r
=
data_shape
.
ndim
();
auto
data_lens
=
inputs
.
front
().
lens
();
auto
q
=
index_shape
.
ndim
();
size_t
k
;
if
(
index_shape
.
dynamic
())
{
// the rank of the output is a function of k, so k must be fixed.
if
(
not
index_shape
.
dyn_dims
().
back
().
is_fixed
())
{
MIGRAPHX_THROW
(
"GATHERND: last dimension of indices tensor must be fixed (min=max)"
);
}
k
=
index_shape
.
dyn_dims
().
back
().
min
;
}
else
k
=
index_shape
.
lens
().
back
();
// Checks on the sizes of input tensors
if
(
q
+
r
!=
upd_shape
.
ndim
()
+
k
+
1
)
MIGRAPHX_THROW
(
"ScatterND: ranks of inputs don't match. "
+
std
::
to_string
(
q
)
+
" + "
+
std
::
to_string
(
r
)
+
" - "
+
std
::
to_string
(
k
)
+
" - 1 != "
+
std
::
to_string
(
upd_shape
.
ndim
()));
if
(
k
>
r
)
if
(
k
>
r
)
MIGRAPHX_THROW
(
"ScatterND: index of size "
+
std
::
to_string
(
k
)
+
MIGRAPHX_THROW
(
"ScatterND: index of size "
+
std
::
to_string
(
k
)
+
" is too large for tensor of rank "
+
std
::
to_string
(
r
));
" is too large for tensor of rank "
+
std
::
to_string
(
r
));
if
(
not
(
std
::
equal
(
ind_lens
.
begin
(),
ind_lens
.
begin
()
+
q
-
1
,
upd_lens
.
begin
())
and
std
::
equal
(
data_lens
.
begin
()
+
k
,
data_lens
.
end
(),
upd_lens
.
begin
()
+
q
-
1
)))
// Convert all static shape dimensions to dynamic so they can be compared.
MIGRAPHX_THROW
(
"ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
// It's possible for some of the 3 inputs to be dynamic shapes and some static,
"++ data.lens[k:r-1]"
);
// but any dynamic dimension that's compared to a static dimension must be fixed.
auto
s
=
inputs
.
front
();
auto
ind_dims
=
index_shape
.
to_dynamic
().
dyn_dims
();
if
(
s
.
broadcasted
())
auto
upd_dims
=
upd_shape
.
to_dynamic
().
dyn_dims
();
auto
data_dims
=
data_shape
.
to_dynamic
().
dyn_dims
();
// Check that corresponding portions of tensor shapes match.
if
(
not
(
std
::
equal
(
ind_dims
.
begin
(),
ind_dims
.
begin
()
+
q
-
1
,
upd_dims
.
begin
())
and
std
::
equal
(
data_dims
.
begin
()
+
k
,
data_dims
.
end
(),
upd_dims
.
begin
()
+
q
-
1
)))
MIGRAPHX_THROW
(
"ScatterND: incorrect update shape. Update dimensions must match "
"indices and data."
);
if
(
data_shape
.
dynamic
())
return
data_shape
;
else
if
(
data_shape
.
broadcasted
())
{
{
return
{
s
.
type
(),
s
.
lens
()};
return
{
data_shape
.
type
(),
data_shape
.
lens
()};
}
}
else
else
{
{
return
s
.
with_lens
(
s
.
lens
());
return
data_shape
.
with_lens
(
data_shape
.
lens
());
}
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
updates
)
{
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
updates
)
{
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
...
@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
...
@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto
updates_std
=
shape
{
updates_shape
.
type
(),
updates_shape
.
lens
()};
auto
updates_std
=
shape
{
updates_shape
.
type
(),
updates_shape
.
lens
()};
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape
=
indices
.
get_shape
();
auto
k
=
indices_shape
.
lens
().
back
();
auto
k
=
indices_shape
.
lens
().
back
();
auto
q
=
indices_shape
.
lens
().
size
();
auto
q
=
indices_shape
.
ndim
();
auto
r
=
out
put_shape
.
lens
().
size
();
auto
r
=
dyn_out
.
com
put
ed
_shape
.
ndim
();
par_for
(
updates_shape
.
elements
(),
[
&
](
const
auto
i
)
{
par_for
(
updates_shape
.
elements
(),
[
&
](
const
auto
i
)
{
auto
updates_idx
=
updates_std
.
multi
(
i
);
auto
updates_idx
=
updates_std
.
multi
(
i
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
...
@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
...
@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std
::
copy
(
index_start
,
index_end
,
out_idx
.
begin
());
std
::
copy
(
index_start
,
index_end
,
out_idx
.
begin
());
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
self
.
reduction
()(
output
[
out
put_shape
.
index
(
out_idx
)],
updates
[
i
]);
self
.
reduction
()(
output
[
dyn_out
.
com
put
ed
_shape
.
index
(
out_idx
)],
updates
[
i
]);
});
});
});
});
});
});
...
...
src/include/migraphx/op/select_module.hpp
0 → 100644
View file @
30c49503
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
select_module
{
shape
output_dyn_shapes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
output_dyn_shapes
,
"output_dyn_shapes"
));
}
std
::
string
name
()
const
{
return
"select_module"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has_at_least
(
1
);
return
shape
{
output_dyn_shapes
};
}
std
::
vector
<
std
::
string
>
get_input_parameter_names
(
module_ref
mod
)
const
{
auto
param_names
=
mod
->
get_parameter_names
();
std
::
vector
<
std
::
string
>
ret
;
std
::
copy_if
(
param_names
.
cbegin
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
not
contains
(
pn
,
"#output_"
);
});
return
ret
;
}
std
::
vector
<
std
::
string
>
get_output_parameter_names
(
module_ref
mod
)
const
{
auto
param_names
=
mod
->
get_parameter_names
();
std
::
vector
<
std
::
string
>
ret
;
std
::
copy_if
(
param_names
.
cbegin
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
contains
(
pn
,
"#output_"
);
});
return
ret
;
}
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
module_ref
>&
submodule_list
,
const
std
::
function
<
std
::
vector
<
argument
>
(
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
)
>&
run
)
const
{
// Find submodule with input parameter shapes exactly the same as the input instruction
// arguments. Assuming instruction arguments are in the same order as the instruction
// parameters.
auto
module_iter
=
std
::
find_if
(
submodule_list
.
cbegin
(),
submodule_list
.
cend
(),
[
&
](
module_ref
mr
)
{
auto
in_param_names
=
get_input_parameter_names
(
mr
);
auto
param_shapes
=
mr
->
get_parameter_shapes
();
assert
(
in_param_names
.
size
()
<=
args
.
size
());
return
std
::
equal
(
in_param_names
.
cbegin
(),
in_param_names
.
cend
(),
args
.
cbegin
(),
[
&
](
auto
p_name
,
auto
a
)
{
return
a
.
get_shape
()
==
param_shapes
[
p_name
];
});
});
if
(
module_iter
==
submodule_list
.
end
())
{
MIGRAPHX_THROW
(
"SELECT_MODULE: no compatible submodules found for given input shapes"
);
}
auto
*
module_to_run
=
*
module_iter
;
std
::
unordered_map
<
std
::
string
,
argument
>
p_map
;
// add input parameters to parameter_map
auto
in_param_names
=
get_input_parameter_names
(
module_to_run
);
assert
(
in_param_names
.
size
()
<=
args
.
size
());
std
::
transform
(
in_param_names
.
begin
(),
in_param_names
.
end
(),
args
.
begin
(),
std
::
inserter
(
p_map
,
p_map
.
end
()),
[
&
](
auto
&&
name
,
auto
&&
a
)
{
return
std
::
make_pair
(
name
,
a
);
});
// One tuple output parameter in main module to multiple output parameters in submodule
auto
out_param_names
=
get_output_parameter_names
(
module_to_run
);
auto
output_sub_objects
=
args
.
back
().
get_sub_objects
();
assert
(
out_param_names
.
size
()
==
output_sub_objects
.
size
());
std
::
transform
(
out_param_names
.
begin
(),
out_param_names
.
end
(),
output_sub_objects
.
begin
(),
std
::
inserter
(
p_map
,
p_map
.
end
()),
[
&
](
auto
&&
name
,
auto
&&
a
)
{
auto
ps
=
module_to_run
->
get_parameter_shape
(
name
);
if
(
a
.
get_shape
()
!=
ps
)
{
assert
(
ps
.
bytes
()
==
a
.
get_shape
().
bytes
());
return
std
::
make_pair
(
name
,
a
.
reshape
(
ps
));
}
else
{
return
std
::
make_pair
(
name
,
a
);
}
});
auto
results
=
run
(
module_to_run
,
p_map
);
return
argument
{
results
};
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/slice.hpp
View file @
30c49503
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
...
@@ -46,6 +47,10 @@ struct slice
...
@@ -46,6 +47,10 @@ struct slice
return
pack
(
f
(
self
.
axes
,
"axes"
),
f
(
self
.
starts
,
"starts"
),
f
(
self
.
ends
,
"ends"
));
return
pack
(
f
(
self
.
axes
,
"axes"
),
f
(
self
.
starts
,
"starts"
),
f
(
self
.
ends
,
"ends"
));
}
}
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value
attributes
()
const
value
attributes
()
const
{
{
value
normalize
=
value
::
object
{};
value
normalize
=
value
::
object
{};
...
@@ -65,14 +70,6 @@ struct slice
...
@@ -65,14 +70,6 @@ struct slice
std
::
string
name
()
const
{
return
"slice"
;
}
std
::
string
name
()
const
{
return
"slice"
;
}
auto
fix_index
(
const
std
::
vector
<
std
::
size_t
>&
lens
,
std
::
size_t
axis
,
int64_t
index
)
const
{
int64_t
r
=
std
::
min
(
index
,
static_cast
<
int64_t
>
(
lens
[
axis
]));
if
(
r
<
0
)
r
+=
lens
[
axis
];
return
std
::
size_t
(
r
);
}
auto
compute_offset
(
const
shape
&
s
)
const
auto
compute_offset
(
const
shape
&
s
)
const
{
{
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
...
@@ -83,14 +80,14 @@ struct slice
...
@@ -83,14 +80,14 @@ struct slice
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
{
auto
axis
=
axes
[
i
];
auto
axis
=
axes
[
i
];
offset
+=
fix_index
(
lens
,
axis
,
starts
[
i
]
)
*
strides
[
axis
];
offset
+=
starts
[
i
]
*
strides
[
axis
];
}
}
}
}
else
else
{
{
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
{
{
offset
+=
fix_index
(
lens
,
axis
,
starts
[
axis
]
)
*
strides
[
axis
];
offset
+=
starts
[
axis
]
*
strides
[
axis
];
}
}
}
}
return
offset
;
return
offset
;
...
@@ -98,37 +95,81 @@ struct slice
...
@@ -98,37 +95,81 @@ struct slice
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
auto
t
=
input_shape
.
type
();
const
auto
&
old_lens
=
input_shape
.
lens
();
const
auto
&
old_strides
=
input_shape
.
strides
();
if
(
std
::
any_of
(
// TODO: When support for dynamic shapes is added to normalize_attributes,
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
i
)
{
return
(
i
>=
old_lens
.
size
()
and
i
<
0
);
}))
// remove this restriction.
if
(
input_shape
.
dynamic
()
and
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
not
input_shape
.
dyn_dims
()[
axis
].
is_fixed
();
}))
{
{
MIGRAPHX_THROW
(
"SLICE:
input axis "
+
to_string_range
(
axes
)
+
" out of range
"
);
MIGRAPHX_THROW
(
"SLICE:
slicing is not allowed on non-fixed dynamic input axis
"
);
}
}
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
// For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std
::
vector
<
std
::
size_t
>
new_mins
;
std
::
vector
<
std
::
size_t
>
new_opts
;
std
::
vector
<
std
::
size_t
>
old_lens
;
std
::
vector
<
std
::
size_t
>
old_strides
;
if
(
input_shape
.
dynamic
())
{
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
old_lens
=
input_shape
.
max_lens
();
new_mins
=
input_shape
.
min_lens
();
new_opts
=
input_shape
.
opt_lens
();
}
else
{
old_lens
=
input_shape
.
lens
();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides
=
input_shape
.
strides
();
}
}
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
{
auto
axis
=
axes
[
i
];
auto
axis
=
axes
[
i
];
new_lens
[
axis
]
=
size_t
sliced_length
=
ends
[
i
]
-
starts
[
i
];
fix_index
(
old_lens
,
axis
,
ends
[
i
])
-
fix_index
(
old_lens
,
axis
,
starts
[
i
]);
// A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens
[
axis
]
=
std
::
min
(
new_lens
[
axis
],
sliced_length
);
if
(
input_shape
.
dynamic
())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std
::
size_t
sliced_min_length
=
ends
[
i
]
-
starts
[
i
];
// if the slice size is smaller than maxes but larger than mins
new_mins
[
axis
]
=
std
::
min
(
sliced_min_length
,
new_mins
[
axis
]);
auto
sliced_opt_length
=
ends
[
i
]
-
starts
[
i
];
if
(
new_opts
[
axis
]
!=
0
)
new_opts
[
axis
]
=
sliced_opt_length
;
if
(
new_opts
[
axis
]
<
new_mins
[
axis
]
or
new_opts
[
axis
]
>
new_lens
[
axis
])
new_opts
[
axis
]
=
0
;
}
}
if
(
input_shape
.
dynamic
())
{
return
shape
{
t
,
new_mins
,
new_lens
,
new_opts
};
}
}
else
{
return
shape
{
t
,
new_lens
,
old_strides
};
return
shape
{
t
,
new_lens
,
old_strides
};
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
input
=
args
[
0
];
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
return
{
std
::
move
(
output_shape
),
[
=
]
{
return
input
.
data
()
+
offset
;
}};
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
dyn_out
.
computed_shape
.
type_size
();
return
{
dyn_out
.
computed_shape
,
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/where.hpp
View file @
30c49503
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -42,9 +42,17 @@ struct where
...
@@ -42,9 +42,17 @@ struct where
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
same_dims
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
3
).
same_dims
();
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
auto
s2
=
inputs
.
at
(
2
);
auto
s2
=
inputs
.
at
(
2
);
if
(
s1
.
dynamic
()
or
s2
.
dynamic
())
{
if
(
s1
==
s2
)
return
s1
;
MIGRAPHX_THROW
(
"WHERE: dynamic input shapes must be the same"
);
}
// Compare two static shapes, returning a standard shape
if
(
s1
==
s2
and
s1
.
packed
())
if
(
s1
==
s2
and
s1
.
packed
())
{
{
return
s1
;
return
s1
;
...
@@ -63,12 +71,12 @@ struct where
...
@@ -63,12 +71,12 @@ struct where
}
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
1
],
args
[
2
])([
&
](
auto
output
,
const
auto
x
,
const
auto
y
)
{
visit_all
(
result
,
args
[
1
],
args
[
2
])([
&
](
auto
output
,
const
auto
x
,
const
auto
y
)
{
args
[
0
].
visit
([
&
](
const
auto
condition
)
{
args
[
0
].
visit
([
&
](
const
auto
condition
)
{
par_for
(
out
put_shape
.
elements
(),
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
condition
[
i
]
?
x
[
i
]
:
y
[
i
];
});
[
&
](
auto
i
)
{
output
[
i
]
=
condition
[
i
]
?
x
[
i
]
:
y
[
i
];
});
});
});
});
});
...
...
src/include/migraphx/operation.hpp
View file @
30c49503
...
@@ -140,6 +140,8 @@ template <class T>
...
@@ -140,6 +140,8 @@ template <class T>
auto
compute_shape_op
(
rank
<
2
>
,
const
T
&
x
,
const
std
::
vector
<
shape
>&
inputs
)
auto
compute_shape_op
(
rank
<
2
>
,
const
T
&
x
,
const
std
::
vector
<
shape
>&
inputs
)
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
{
{
if
(
inputs
.
empty
())
MIGRAPHX_THROW
(
"At least one input is required for "
+
x
.
name
());
dependent_type
<
operation
,
T
>
y
=
x
;
dependent_type
<
operation
,
T
>
y
=
x
;
normalize_attributes
(
y
,
inputs
[
0
].
max_lens
());
normalize_attributes
(
y
,
inputs
[
0
].
max_lens
());
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
...
...
src/include/migraphx/
pass_config
.hpp
→
src/include/migraphx/
optimize_module
.hpp
View file @
30c49503
...
@@ -21,18 +21,28 @@
...
@@ -21,18 +21,28 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#ifndef MIGRAPHX_GUARD_PASS_CONFIG_HPP
#include <string>
#define MIGRAPHX_GUARD_PASS_CONFIG_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_MEMORY_COLORING
)
struct
module_pass_manager
;
/**
* Runs several passes in a loop
*/
struct
optimize_module
{
std
::
string
name
()
const
{
return
"optimize_module"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_PASS_CONFIG_HPP
#endif
src/include/migraphx/register_op.hpp
View file @
30c49503
...
@@ -33,15 +33,36 @@
...
@@ -33,15 +33,36 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void
unregister_op
(
const
std
::
string
&
op_name
);
namespace
detail
{
struct
op_handler
{
operation
op
;
std
::
string
name
;
op_handler
(
const
operation
&
op_r
)
:
op
(
op_r
),
name
(
op
.
name
()){};
~
op_handler
()
{
unregister_op
(
name
);
}
};
}
// namespace detail
void
register_op_init
();
void
register_op
(
const
operation
&
op
);
void
register_op
(
const
operation
&
op
);
operation
load_op
(
const
std
::
string
&
name
);
operation
load_op
(
const
std
::
string
&
name
);
bool
has_op
(
const
std
::
string
&
name
);
bool
has_op
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_operators
();
std
::
vector
<
std
::
string
>
get_operators
();
template
<
class
T
>
template
<
class
T
>
void
register_op
()
void
register_op
()
{
{
register_op
(
T
{});
register_op_init
();
// instantiate static op_map;
static
auto
op_h
=
detail
::
op_handler
(
T
{});
register_op
(
op_h
.
op
);
}
}
struct
register_op_action
struct
register_op_action
...
...
src/include/migraphx/register_target.hpp
View file @
30c49503
...
@@ -33,14 +33,28 @@
...
@@ -33,14 +33,28 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
register_target_init
();
void
register_target
(
const
target
&
t
);
void
register_target
(
const
target
&
t
);
void
unregister_target
(
const
std
::
string
&
name
);
target
make_target
(
const
std
::
string
&
name
);
target
make_target
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_targets
();
std
::
vector
<
std
::
string
>
get_targets
();
namespace
detail
{
struct
target_handler
{
target
t
;
std
::
string
target_name
;
target_handler
(
const
target
&
t_r
)
:
t
(
t_r
),
target_name
(
t
.
name
())
{}
~
target_handler
()
{
unregister_target
(
target_name
);
}
};
}
// namespace detail
template
<
class
T
>
template
<
class
T
>
void
register_target
()
void
register_target
()
{
{
register_target
(
T
{});
register_target_init
();
static
auto
t_h
=
detail
::
target_handler
(
T
{});
register_target
(
t_h
.
t
);
}
}
struct
register_target_action
struct
register_target_action
...
...
src/include/migraphx/replace_allocate.hpp
View file @
30c49503
...
@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
struct
module
;
/**
* Replace `allocate` instructions with target allocations or output parameters.
*/
struct
replace_allocate
struct
replace_allocate
{
{
allocation_model
model
;
allocation_model
model
;
...
...
src/include/migraphx/serialize.hpp
View file @
30c49503
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <type_traits>
#include <type_traits>
...
@@ -60,11 +61,12 @@ value to_value_impl(rank<0>, const T&)
...
@@ -60,11 +61,12 @@ value to_value_impl(rank<0>, const T&)
return
value
::
object
{};
return
value
::
object
{};
}
}
template
<
class
T
,
class
U
>
template
<
class
T
>
value
to_value_impl
(
rank
<
1
>
,
const
std
::
pair
<
T
,
U
>&
x
)
auto
to_value_impl
(
rank
<
1
>
,
const
T
&
x
)
->
decltype
(
std
::
tuple_size
<
T
>
{},
value
{}
)
{
{
value
result
=
value
::
array
{};
return
{
x
.
first
,
x
.
second
};
repeat_c
<
std
::
tuple_size
<
T
>
{}
>
([
&
](
auto
i
)
{
result
.
push_back
(
to_value
(
std
::
get
<
i
>
(
x
)));
});
return
result
;
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -86,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
...
@@ -86,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return
result
;
return
result
;
}
}
template
<
class
T
>
auto
to_value_impl
(
rank
<
4
>
,
const
optional
<
T
>&
x
)
{
value
result
{};
if
(
x
.
has_value
())
return
to_value
(
*
x
);
return
result
;
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_signed
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_signed
<
T
>{})
>
value
to_value_impl
(
rank
<
4
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
5
>
,
const
T
&
x
)
{
{
return
std
::
int64_t
{
x
};
return
std
::
int64_t
{
x
};
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_unsigned
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_unsigned
<
T
>{})
>
value
to_value_impl
(
rank
<
5
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
6
>
,
const
T
&
x
)
{
{
return
std
::
uint64_t
{
x
};
return
std
::
uint64_t
{
x
};
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
value
to_value_impl
(
rank
<
6
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
7
>
,
const
T
&
x
)
{
{
return
double
{
x
};
return
double
{
x
};
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
value
to_value_impl
(
rank
<
7
>
,
const
T
&
x
)
value
to_value_impl
(
rank
<
8
>
,
const
T
&
x
)
{
{
return
x
;
return
x
;
}
}
inline
value
to_value_impl
(
rank
<
8
>
,
const
std
::
string
&
x
)
{
return
x
;
}
inline
value
to_value_impl
(
rank
<
9
>
,
const
std
::
string
&
x
)
{
return
x
;
}
template
<
class
T
>
template
<
class
T
>
auto
to_value_impl
(
rank
<
9
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
x
))
auto
to_value_impl
(
rank
<
10
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
x
))
{
{
return
migraphx_to_value
(
x
);
return
migraphx_to_value
(
x
);
}
}
template
<
class
T
>
template
<
class
T
>
auto
to_value_impl
(
rank
<
1
0
>
,
const
T
&
x
)
->
decltype
(
x
.
to_value
())
auto
to_value_impl
(
rank
<
1
1
>
,
const
T
&
x
)
->
decltype
(
x
.
to_value
())
{
{
return
x
.
to_value
();
return
x
.
to_value
();
}
}
template
<
class
T
>
template
<
class
T
>
auto
to_value_impl
(
rank
<
1
1
>
,
const
T
&
x
)
auto
to_value_impl
(
rank
<
1
2
>
,
const
T
&
x
)
->
decltype
(
migraphx_to_value
(
std
::
declval
<
value
&>
(),
x
),
value
{})
->
decltype
(
migraphx_to_value
(
std
::
declval
<
value
&>
(),
x
),
value
{})
{
{
value
v
;
value
v
;
...
@@ -144,7 +155,14 @@ void from_value_impl(rank<0>, const value& v, T& x)
...
@@ -144,7 +155,14 @@ void from_value_impl(rank<0>, const value& v, T& x)
}
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
1
>
,
const
value
&
v
,
T
&
x
)
auto
from_value_impl
(
rank
<
1
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
std
::
tuple_size
<
T
>
{},
void
())
{
repeat_c
<
std
::
tuple_size
<
T
>
{}
>
(
[
&
](
auto
i
)
{
std
::
get
<
i
>
(
x
)
=
from_value
<
std
::
tuple_element_t
<
i
,
T
>>
(
v
[
i
]);
});
}
template
<
class
T
>
auto
from_value_impl
(
rank
<
2
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
x
.
end
(),
*
x
.
begin
()),
void
())
->
decltype
(
x
.
insert
(
x
.
end
(),
*
x
.
begin
()),
void
())
{
{
x
.
clear
();
x
.
clear
();
...
@@ -153,7 +171,7 @@ auto from_value_impl(rank<1>, const value& v, T& x)
...
@@ -153,7 +171,7 @@ auto from_value_impl(rank<1>, const value& v, T& x)
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
typename
T
::
value_type
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
typename
T
::
value_type
>{})
>
auto
from_value_impl
(
rank
<
2
>
,
const
value
&
v
,
T
&
x
)
auto
from_value_impl
(
rank
<
3
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
x
.
end
(),
*
x
.
begin
()),
void
())
->
decltype
(
x
.
insert
(
x
.
end
(),
*
x
.
begin
()),
void
())
{
{
x
.
clear
();
x
.
clear
();
...
@@ -170,7 +188,7 @@ auto from_value_impl(rank<2>, const value& v, T& x)
...
@@ -170,7 +188,7 @@ auto from_value_impl(rank<2>, const value& v, T& x)
}
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
3
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
*
x
.
begin
()),
void
())
auto
from_value_impl
(
rank
<
4
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
*
x
.
begin
()),
void
())
{
{
x
.
clear
();
x
.
clear
();
for
(
auto
&&
e
:
v
)
for
(
auto
&&
e
:
v
)
...
@@ -178,7 +196,7 @@ auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begi
...
@@ -178,7 +196,7 @@ auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begi
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_reflectable
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_reflectable
<
T
>{})
>
void
from_value_impl
(
rank
<
4
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
5
>
,
const
value
&
v
,
T
&
x
)
{
{
reflect_each
(
x
,
[
&
](
auto
&
y
,
const
std
::
string
&
name
)
{
reflect_each
(
x
,
[
&
](
auto
&
y
,
const
std
::
string
&
name
)
{
using
type
=
std
::
decay_t
<
decltype
(
y
)
>
;
using
type
=
std
::
decay_t
<
decltype
(
y
)
>
;
...
@@ -187,28 +205,29 @@ void from_value_impl(rank<4>, const value& v, T& x)
...
@@ -187,28 +205,29 @@ void from_value_impl(rank<4>, const value& v, T& x)
});
});
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{})
>
template
<
class
T
>
void
from_value_impl
(
rank
<
5
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
6
>
,
const
value
&
v
,
optional
<
T
>
&
x
)
{
{
x
=
v
.
to
<
T
>
();
if
(
not
v
.
is_null
())
x
=
from_value
<
T
>
(
v
);
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{}
or
std
::
is_enum
<
T
>
{})
>
void
from_value_impl
(
rank
<
6
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
{
{
x
=
v
.
to
<
T
>
();
x
=
v
.
to
<
T
>
();
}
}
inline
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
inline
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
auto
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
{
{
x
.
from_value
(
v
);
x
.
from_value
(
v
);
}
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
auto
from_value_impl
(
rank
<
10
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
{
{
migraphx_from_value
(
v
,
x
);
migraphx_from_value
(
v
,
x
);
}
}
...
@@ -218,13 +237,13 @@ auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_va
...
@@ -218,13 +237,13 @@ auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_va
template
<
class
T
>
template
<
class
T
>
value
to_value
(
const
T
&
x
)
value
to_value
(
const
T
&
x
)
{
{
return
detail
::
to_value_impl
(
rank
<
1
1
>
{},
x
);
return
detail
::
to_value_impl
(
rank
<
1
2
>
{},
x
);
}
}
template
<
class
T
>
template
<
class
T
>
void
from_value
(
const
value
&
v
,
T
&
x
)
void
from_value
(
const
value
&
v
,
T
&
x
)
{
{
detail
::
from_value_impl
(
rank
<
9
>
{},
v
,
x
);
detail
::
from_value_impl
(
rank
<
10
>
{},
v
,
x
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment