gaoqiong / MIGraphX · Commits

Commit 9bff4331, authored Mar 21, 2023 by Paul
Merge commit (parents: 214b313f, 94a7f6ee)
Changes: 274 files in the full commit

Showing 20 changed files with 537 additions and 187 deletions (+537 -187)
  src/include/migraphx/context.hpp                   +3   -1
  src/include/migraphx/file_buffer.hpp               +1   -1
  src/include/migraphx/half.hpp                      +3   -3
  src/include/migraphx/instruction.hpp               +2   -0
  src/include/migraphx/instruction_ref.hpp           +2   -2
  src/include/migraphx/match/layernorm.hpp           +30  -6
  src/include/migraphx/memory_coloring.hpp           +3   -2
  src/include/migraphx/module.hpp                    +6   -0
  src/include/migraphx/op/allocate.hpp               +1   -1
  src/include/migraphx/op/argmax.hpp                 +19  -11
  src/include/migraphx/op/concat.hpp                 +64  -25
  src/include/migraphx/op/dot.hpp                    +51  -22
  src/include/migraphx/op/flatten.hpp                +39  -9
  src/include/migraphx/op/gather.hpp                 +40  -15
  src/include/migraphx/op/gathernd.hpp               +88  -17
  src/include/migraphx/op/nonmaxsuppression.hpp      +33  -31
  src/include/migraphx/op/normalize_attribute.hpp    +21  -9
  src/include/migraphx/op/pad.hpp                    +21  -10
  src/include/migraphx/op/reduce_op.hpp              +39  -15
  src/include/migraphx/op/reshape.hpp                +71  -7
src/include/migraphx/context.hpp (view file @ 9bff4331)

@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&)
 {
     return {};
 }

 template <class T>
 void wait_for_context(T&, any_ptr)
 {

@@ -302,7 +303,7 @@ struct context
         PrivateDetailTypeErasedT value,
         typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                 int>::type* = nullptr) noexcept
-        : private_detail_te_value(value)
+        : private_detail_te_value(std::move(value))
     {
     }

@@ -412,6 +413,7 @@ inline const ValueType& any_cast(const context& x)
 #endif

 inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
 inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }

 #endif
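The only functional change in this file is constructing the type-erased member from std::move(value) instead of copying the by-value parameter. A minimal standalone sketch of why that matters; the wrapper and payload types below are hypothetical stand-ins, not MIGraphX code:

    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical payload with an expensive copy.
    struct payload
    {
        std::vector<std::string> data;
    };

    // Hypothetical type-erasing wrapper that takes its argument by value.
    struct wrapper
    {
        payload private_detail_te_value;

        // Initializing the member with std::move reuses the parameter's buffers
        // instead of copying them a second time.
        explicit wrapper(payload value) noexcept
            : private_detail_te_value(std::move(value))
        {
        }
    };

    int main()
    {
        payload p{{"a", "b", "c"}};
        wrapper w{std::move(p)}; // one move into the parameter, one move into the member
        return w.private_detail_te_value.data.size() == 3 ? 0 : 1;
    }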
src/include/migraphx/file_buffer.hpp (view file @ 9bff4331)

@@ -31,7 +31,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-std::vector<char> read_buffer(const std::string& filename);
+std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
 std::string read_string(const std::string& filename);

 void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
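The new read_buffer signature adds optional offset and nbytes parameters, which suggests reading a sub-range of a file. A rough sketch of that behaviour with standard iostreams (an illustration of the likely semantics, not the MIGraphX implementation; error handling omitted):

    #include <fstream>
    #include <string>
    #include <vector>

    // Read `nbytes` bytes starting at `offset`; nbytes == 0 means "to end of file".
    std::vector<char> read_buffer_sketch(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
    {
        std::ifstream is(filename, std::ios::binary | std::ios::ate);
        const auto file_size = static_cast<size_t>(is.tellg());
        if(nbytes == 0)
            nbytes = file_size - offset;
        std::vector<char> buffer(nbytes);
        is.seekg(static_cast<std::streamoff>(offset));
        is.read(buffer.data(), static_cast<std::streamsize>(nbytes));
        return buffer;
    }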
src/include/migraphx/half.hpp (view file @ 9bff4331)

@@ -25,7 +25,7 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP
 #define MIGRAPHX_GUARD_RTGLIB_HALF_HPP

-#include <half.hpp>
+#include <half/half.hpp>
 #include <migraphx/config.hpp>

 namespace migraphx {

@@ -58,12 +58,12 @@ using deduce = typename detail::deduce<T>::type;
 namespace std {

 template <class T>
-struct common_type<migraphx::half, T> : std::common_type<float, T>
+struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
 {
 };

 template <class T>
-struct common_type<T, migraphx::half> : std::common_type<float, T>
+struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
 {
 };
src/include/migraphx/instruction.hpp (view file @ 9bff4331)

@@ -121,6 +121,8 @@ struct instruction
     bool can_eval() const;

+    bool is_undefined() const;
+
     argument eval(bool check_eval = true) const;

     void finalize(context& ctx);
src/include/migraphx/instruction_ref.hpp (view file @ 9bff4331)

@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
 namespace std {

 template <>
-struct hash<migraphx::instruction_ref>
+struct hash<migraphx::instruction_ref> // NOLINT
 {
     using argument_type = migraphx::instruction_ref;
     using result_type   = std::size_t;

@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
 };

 template <>
-struct equal_to<migraphx::instruction_ref>
+struct equal_to<migraphx::instruction_ref> // NOLINT
 {
     using argument_type = migraphx::instruction_ref;
     using result_type   = bool;
src/include/migraphx/match/layernorm.hpp (view file @ 9bff4331)

@@ -36,22 +36,46 @@ template <class F>
 struct layernorm_matcher
 {
     F f;

+    auto last_axis() const
+    {
+        return make_basic_pred_matcher([](instruction_ref ins) {
+            auto v = ins->get_operator().to_value();
+            if(not v.contains("axes"))
+                return false;
+            auto axes = v["axes"].to_vector<std::size_t>();
+            if(axes.size() != 1)
+                return false;
+            return axes.front() == ins->inputs().front()->get_shape().lens().size() - 1;
+        });
+    }
+
+    auto reduce_mean() const { return f("reduce_mean")(last_axis()); }
+
     auto x_minus_mean() const
     {
-        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean"))));
+        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(reduce_mean())));
     }

     auto variance() const
     {
-        return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
+        return reduce_mean()(arg(0)(any_of(
+            f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
+            f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
+            f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(reduce_mean()))))));
     }

-    auto layernorm_onnx() const
+    auto sqrt_add_eps(const std::string& name) const
     {
-        return f("div")(arg(0)(x_minus_mean()),
-                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
-                            f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
+        auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
+        return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
+    }
+
+    auto layernorm_onnx() const
+    {
+        auto div_sqrt  = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
+        auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
+        return any(any_of(div_sqrt, mul_rsqrt));
     }

     auto matcher() const { return layernorm_onnx(); }
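The widened variance() matcher accepts three algebraically equivalent spellings of the squared deviation: pow(x - mean, 2), (x - mean) * (x - mean), and sqdiff(x, mean). A small standalone check of that equivalence in plain C++ (not matcher code):

    #include <cassert>
    #include <cmath>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::vector<double> x = {1.0, 2.0, 3.0, 4.0};
        double mean = std::accumulate(x.begin(), x.end(), 0.0) / x.size();

        double v_pow = 0, v_mul = 0, v_sqdiff = 0;
        for(double xi : x)
        {
            double d = xi - mean;
            v_pow += std::pow(d, 2.0);              // matched as pow(sub(x, mean), 2)
            v_mul += d * d;                         // matched as mul(sub(x, mean), sub(x, mean))
            v_sqdiff += (xi - mean) * (xi - mean);  // matched as sqdiff(x, mean)
        }
        // All three give the same (unnormalized) variance numerator.
        assert(std::abs(v_pow - v_mul) < 1e-12 and std::abs(v_mul - v_sqdiff) < 1e-12);
        return 0;
    }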
src/include/migraphx/memory_coloring.hpp (view file @ 9bff4331)

@@ -33,13 +33,14 @@ inline namespace MIGRAPHX_INLINE_NS {
 struct module;

 /**
- * Remove memory allocations. It uses graph coloring to find memory allocations that can be reused.
+ * Remove multiple memory allocations using graph coloring to find memory allocations that can be
+ * reused.
  */
 struct memory_coloring
 {
     std::string allocation_op{};
     bool verify = false;
-    std::string name() const { return "memory coloring"; }
+    std::string name() const { return "memory_coloring"; }
     void apply(module& m) const;
 };
src/include/migraphx/module.hpp (view file @ 9bff4331)

@@ -205,6 +205,12 @@ struct module
     void print_graph(std::ostream& os, bool brief = false) const;

+    void print_py(std::ostream& os) const;
+    std::unordered_map<instruction_ref, std::string>
+    print_py(std::ostream& os,
+             const std::string& mname,
+             std::unordered_map<instruction_ref, std::string> names) const;
+
     void print_cpp(std::ostream& os) const;
     std::unordered_map<instruction_ref, std::string>
     print_cpp(std::ostream& os,
src/include/migraphx/op/allocate.hpp (view file @ 9bff4331)

@@ -44,7 +44,7 @@ struct allocate
     std::string name() const { return "allocate"; }

     shape compute_shape(const std::vector<shape>& inputs) const
     {
-        migraphx::check_shapes{inputs, *this}.has(0);
+        migraphx::check_shapes{inputs, *this, true}.has(0);
         return s;
     }

     argument compute(const shape& output_shape, const std::vector<argument>&) const
src/include/migraphx/op/argmax.hpp (view file @ 9bff4331)

@@ -30,6 +30,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -56,12 +57,20 @@ struct argmax
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        auto lens  = inputs[0].lens();
-        lens[axis] = 1;
-        return {shape::int64_type, lens};
+        check_shapes{inputs, *this, true}.has(1);
+        const auto& s0 = inputs[0];
+        if(s0.dynamic())
+        {
+            auto dyn_dims  = s0.dyn_dims();
+            dyn_dims[axis] = {1, 1, 0};
+            return {shape::int64_type, dyn_dims};
+        }
+        else
+        {
+            auto lens  = s0.lens();
+            lens[axis] = 1;
+            return {shape::int64_type, lens};
+        }
     }

     template <class T>

@@ -79,19 +88,18 @@ struct argmax
                 max_index = i;
             }
         }
         return max_index;
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         auto batch_item_num = args.front().get_shape().lens()[axis];
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
-                par_for(output_shape.elements(), [&](auto i) {
-                    auto data_idx = output_shape.multi(i);
+                par_for(dyn_out.computed_shape.elements(), [&](auto i) {
+                    auto data_idx = dyn_out.computed_shape.multi(i);
                     output[i] = this->calc_argmax(input, data_idx, batch_item_num);
                 });
             });
         });
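argmax now emits an int64 shape whose reduced axis becomes 1 (or the fixed dynamic_dimension {1, 1, 0}), and compute reads the final shape from dyn_out.computed_shape. For reference, a standalone sketch of the per-output-element argmax itself, reduced to a row-major 2-D case with the standard library only (MIGraphX's multi-index arithmetic is simplified away):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Argmax along the columns of a row-major rows x cols matrix:
    // for every row, return the column index holding the largest value.
    std::vector<std::int64_t> argmax_rows(const std::vector<float>& data, std::size_t rows, std::size_t cols)
    {
        std::vector<std::int64_t> out(rows, 0);
        for(std::size_t r = 0; r < rows; ++r)
        {
            float best = data[r * cols];
            for(std::size_t c = 1; c < cols; ++c)
            {
                if(data[r * cols + c] > best)
                {
                    best   = data[r * cols + c];
                    out[r] = static_cast<std::int64_t>(c);
                }
            }
        }
        return out; // logical shape {rows, 1} collapsed to {rows}
    }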
src/include/migraphx/op/concat.hpp (view file @ 9bff4331)

@@ -26,6 +26,7 @@
 #include <array>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/literal.hpp>

@@ -73,49 +74,87 @@ struct concat
         }
         return offsets;
     }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        if(inputs.empty())
-        {
-            MIGRAPHX_THROW("CONCAT: Number of input tensors should exceed 0");
-        }
-        const auto& first_shape_lens = inputs.front().lens();
-        const auto& type             = inputs.front().type();
-        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
-        {
-            if(l != axis)
-            {
-                if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
-                       return s.lens()[l] == first_shape_lens[l];
-                   }))
-                {
-                    MIGRAPHX_THROW("CONCAT: Non-axis dimensions should match");
-                }
-            }
-        }
-        std::size_t new_dim_axis = 0;
-        for(const auto& input : inputs)
-        {
-            const auto& lens = input.lens();
-            new_dim_axis += lens[axis];
-        }
-        std::vector<std::size_t> new_lens;
-        std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
-        new_lens[axis] = new_dim_axis;
-        return shape::from_permutation(type, new_lens, find_permutation(inputs));
+        // inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
+        // be at least 1.
+        check_shapes{inputs, *this, true}.same_ndims().same_type();
+        if(std::none_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
+        {
+            // Static input shapes
+            const auto& first_shape_lens = inputs.front().lens();
+            const auto& type             = inputs.front().type();
+            for(std::size_t ll = 0; ll < first_shape_lens.size(); ll++)
+            {
+                if(ll != axis)
+                {
+                    if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
+                           return s.lens()[ll] == first_shape_lens[ll];
+                       }))
+                    {
+                        MIGRAPHX_THROW("CONCAT: all input dimensions should match along axis " +
+                                       std::to_string(ll));
+                    }
+                }
+            }
+            std::size_t new_dim_axis = 0;
+            for(const auto& input : inputs)
+            {
+                const auto& lens = input.lens();
+                new_dim_axis += lens[axis];
+            }
+            std::vector<std::size_t> new_lens = first_shape_lens;
+            new_lens[axis]                    = new_dim_axis;
+            return shape::from_permutation(type, new_lens, find_permutation(inputs));
+        }
+        else if(std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
+        {
+            // Dynamic input shapes
+            for(std::size_t index = 0; index < inputs[0].ndim(); index++)
+            {
+                if(index != axis)
+                {
+                    if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
+                           return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
+                       }))
+                        MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
+                                       std::to_string(index));
+                }
+            }
+            std::size_t new_min = 0;
+            std::size_t new_max = 0;
+            for(const auto& input : inputs)
+            {
+                auto ddim = input.dyn_dims()[axis];
+                new_min += ddim.min;
+                new_max += ddim.max;
+            }
+            auto new_dims  = inputs[0].dyn_dims();
+            new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max, 0};
+            return {inputs[0].type(), new_dims};
+        }
+        else
+        {
+            MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
-        std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
+        argument result{dyn_out.computed_shape};
+        std::vector<std::size_t> coffsets = compute_offsets(dyn_out.computed_shape, args);
         for(std::size_t l = 0; l < args.size(); l++)
         {
             auto argl = args[l];
             visit_all(result, argl)([&](auto output, auto input) {
-                auto slice_shape =
-                    shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
+                auto slice_shape = shape{dyn_out.computed_shape.type(),
+                                         input.get_shape().lens(),
+                                         dyn_out.computed_shape.strides()};
                 auto slice = make_view(slice_shape, output.data() + coffsets[l]);
                 std::copy(input.begin(), input.end(), slice.begin());
             });
         }
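For all-dynamic inputs, the concatenated axis becomes an interval: the minima and maxima of every input's dynamic_dimension along axis are summed while the other dimensions must agree. A tiny sketch of that interval arithmetic; dyn_dim below is a hypothetical stand-in mirroring the {min, max, opt} fields used above:

    #include <cstddef>
    #include <vector>

    // Hypothetical stand-in for shape::dynamic_dimension.
    struct dyn_dim
    {
        std::size_t min;
        std::size_t max;
        std::size_t opt;
    };

    // Sum the axis dimension across inputs; every input contributes [min, max].
    dyn_dim concat_axis_dim(const std::vector<std::vector<dyn_dim>>& inputs, std::size_t axis)
    {
        dyn_dim out{0, 0, 0};
        for(const auto& dims : inputs)
        {
            out.min += dims[axis].min;
            out.max += dims[axis].max;
        }
        return out; // e.g. {1,4,0} + {2,2,0} -> {3,6,0}
    }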
src/include/migraphx/op/dot.hpp (view file @ 9bff4331)

@@ -28,6 +28,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/gemm.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -38,41 +39,69 @@ struct dot
     std::string name() const { return "dot"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_type().has(2);
+        check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();

-        if(not std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
-        {
-            MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
-        }
-        // only handle the case that the batch size of a and b are the same
-        if(not std::equal(
-               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
-        {
-            MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
-        }
-        std::size_t dim_0 = a.lens().size() - 2;
-        std::size_t dim_1 = a.lens().size() - 1;
-        if(a.lens()[dim_1] != b.lens()[dim_0])
-        {
-            MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
-        }
-        auto out_lens   = a.lens();
-        out_lens[dim_1] = b.lens()[dim_1];
-        return {t, out_lens};
+        if(not std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.ndim() >= 2; }))
+        {
+            MIGRAPHX_THROW("DOT: dot only accepts operands with 2 or more dimensions");
+        }
+        if(a.dynamic() or b.dynamic())
+        {
+            auto s0 = a.to_dynamic();
+            auto s1 = b.to_dynamic();
+            if(not std::equal(s0.dyn_dims().rbegin() + 2,
+                              s0.dyn_dims().rend(),
+                              s1.dyn_dims().rbegin() + 2,
+                              s1.dyn_dims().rend()))
+            {
+                MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
+                               to_string_range(s0.dyn_dims()) + "} x {" +
+                               to_string_range(s1.dyn_dims()) + "}");
+            }
+            std::size_t dim_0 = s0.ndim() - 2;
+            std::size_t dim_1 = s0.ndim() - 1;
+            if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
+            {
+                MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
+                               to_string_range(s0.dyn_dims()) + "} x {" +
+                               to_string_range(s1.dyn_dims()) + "}");
+            }
+            auto out_dyn_dims   = s0.dyn_dims();
+            out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
+            return {t, out_dyn_dims};
+        }
+        else
+        {
+            // only handle the case that all the dimensions except the last two are the same
+            if(not std::equal(
+                   a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
+            {
+                MIGRAPHX_THROW("DOT: static outer dimensions of A and B mismatch: {" +
+                               to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
+                               "}");
+            }
+            std::size_t dim_0 = a.ndim() - 2;
+            std::size_t dim_1 = a.ndim() - 1;
+            if(a.lens()[dim_1] != b.lens()[dim_0])
+            {
+                MIGRAPHX_THROW("DOT: static inner dimensions do not match: {" +
+                               to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
+                               "}");
+            }
+            auto out_lens   = a.lens();
+            out_lens[dim_1] = b.lens()[dim_1];
+            return {t, out_lens};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result = argument{output_shape};
+        argument result = argument{dyn_out.computed_shape};
         visit_all(result, args[0], args[1])(
             [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
         return result;
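Both the static and dynamic branches enforce the same rule: all leading (batch) dimensions of A and B must match, and A's last dimension must equal B's second-to-last. A minimal standard-library sketch of that check on plain dimension vectors:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // True if {.., m, k} x {.., k, n} is a valid batched matrix product.
    bool dot_shapes_compatible(const std::vector<std::size_t>& a, const std::vector<std::size_t>& b)
    {
        if(a.size() < 2 or a.size() != b.size())
            return false;
        // Outer (batch) dimensions must be identical.
        if(not std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2, b.rend()))
            return false;
        // Inner dimensions must match: a[last] == b[second-to-last].
        return a[a.size() - 1] == b[b.size() - 2];
    }

    // Example: {2, 3, 4} x {2, 4, 5} -> valid, output {2, 3, 5}.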
src/include/migraphx/op/flatten.hpp (view file @ 9bff4331)

@@ -55,17 +55,47 @@ struct flatten
     std::string name() const { return "flatten"; }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
-        auto&& lens = inputs.front().lens();
-        auto x =
-            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
-        auto y =
-            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
-        return {inputs.at(0).type(), {x, y}};
+        check_shapes{inputs, *this, true}.has(1);
+        auto s = inputs[0];
+        if(s.dynamic())
+        {
+            auto min_lens = s.min_lens();
+            auto max_lens = s.max_lens();
+            auto opt_lens = s.opt_lens();
+            // If any of the opt values is 0, output opt will be 0
+            shape::dynamic_dimension x = {
+                std::accumulate(min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(opt_lens.begin(), opt_lens.begin() + axis, std::size_t{1}, std::multiplies<>{})};
+            shape::dynamic_dimension y = {
+                std::accumulate(min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
+            };
+            return {s.type(), {x, y}};
+        }
+        else
+        {
+            check_shapes{inputs, *this}.standard();
+            auto&& lens = s.lens();
+            auto x =
+                std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
+            auto y =
+                std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
+            return {s.type(), {x, y}};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
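flatten collapses everything before axis into one dimension and everything from axis onward into another; the dynamic branch applies the same products separately to the min, max, and opt lengths. The static case in a few lines of standard C++:

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <utility>
    #include <vector>

    // Collapse lens into a 2-D shape {prod(lens[0..axis)), prod(lens[axis..end))}.
    std::pair<std::size_t, std::size_t> flatten_dims(const std::vector<std::size_t>& lens, std::size_t axis)
    {
        auto x = std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y = std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
        return {x, y}; // e.g. {2, 3, 4, 5} with axis = 2 -> {6, 20}
    }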
src/include/migraphx/op/gather.hpp (view file @ 9bff4331)

@@ -26,6 +26,7 @@
 #include <array>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/literal.hpp>

@@ -61,35 +62,59 @@ struct gather
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
-        auto lens = inputs[0].lens();
-        auto type = inputs[0].type();
-        lens.erase(lens.begin() + axis);
-        if(not inputs[1].scalar())
-        {
-            auto ind_lens = inputs[1].lens();
-            lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
-        }
-        // for scalar output
-        if(lens.empty())
-        {
-            return {type};
-        }
-        return {type, lens};
+        check_shapes{inputs, *this, true}.has(2);
+        shape data    = inputs[0];
+        shape indices = inputs[1];
+        auto type     = data.type();
+        // If index_dims is dynamic, convert the data to dynamic too.
+        if(indices.dynamic())
+        {
+            data = data.to_dynamic();
+        }
+        if(data.dynamic())
+        {
+            auto dims = data.dyn_dims();
+            dims.erase(dims.begin() + axis);
+            if(not indices.scalar())
+            {
+                auto index_dims = indices.to_dynamic().dyn_dims();
+                dims.insert(dims.begin() + axis, index_dims.begin(), index_dims.end());
+            }
+            return {type, dims};
+        }
+        else
+        {
+            // Both data and indices are static. indices may be scalar
+            auto lens = data.lens();
+            lens.erase(lens.begin() + axis);
+            if(not indices.scalar())
+            {
+                auto ind_lens = indices.lens();
+                lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
+            }
+            // for scalar output
+            if(lens.empty())
+            {
+                return {type};
+            }
+            return {type, lens};
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         // negative axis means counting dimensions from back
         auto lens                 = args[0].get_shape().lens();
         std::size_t axis_dim_size = lens[axis];
         // max dimension in axis
         visit_all(result, args[0])([&](auto output, auto data) {
             args[1].visit([&](auto indices) {
-                if(output_shape.scalar())
+                if(dyn_out.computed_shape.scalar())
                 {
                     auto in_index = indices.front();
                     in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
src/include/migraphx/op/gathernd.hpp (view file @ 9bff4331)

 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal

@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP

 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/argument.hpp>

@@ -47,33 +48,103 @@ struct gathernd
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
-        auto r = inputs.front().lens().size();
-        auto q = inputs.back().lens().size();
-        auto k = inputs.back().lens().back();
+        check_shapes{inputs, *this, true}.has(2);
+        auto i_shape    = inputs.back();
+        auto data_shape = inputs.front();
+        auto r          = data_shape.ndim();
+        auto q          = i_shape.ndim();
+        size_t k;
+        if(i_shape.dynamic())
+        {
+            // the rank of the output is a function of k, so it must be fixed.
+            if(not i_shape.dyn_dims().back().is_fixed())
+            {
+                MIGRAPHX_THROW("GATHERND: last dimension of indices tensor must be fixed (min=max)");
+            }
+            k = i_shape.dyn_dims().back().min;
+        }
+        else
+            k = i_shape.lens().back();

+        // Begin input validation checks.
+        int output_ndim = int(q) + r - k - batch_dims - 1;
         if(k > r - batch_dims)
         {
             MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
                            " cannot be used to access data of rank " +
                            std::to_string(r - batch_dims));
         }
-        auto indices_lens_iter = inputs.back().lens().begin();
-        auto output_lens_size  = q + r - k - batch_dims - 1;
-        std::vector<std::size_t> output_lens(output_lens_size);
-        std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
-        if(k < r - batch_dims)
-        {
-            auto data_lens = inputs.front().lens();
-            std::copy(
-                data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
-        }
-        shape output_shape{inputs.front().type(), output_lens};
-        return output_shape;
+        if(batch_dims >= q or batch_dims >= r)
+        {
+            MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
+                           std::to_string(batch_dims));
+        }
+        if(output_ndim < 0)
+        {
+            MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
+                           std::to_string(k));
+        }
+        if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
+        {
+            auto indices_lens_iter = i_shape.lens().begin();
+            // A rank 0 output is a scalar
+            if(output_ndim == 0)
+                return shape{data_shape.type(), {1}};
+            // Part of the output shape comes from indices tensor, part from data tensor
+            std::vector<std::size_t> output_lens(output_ndim);
+            std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
+            // fill the rest of output shape from data tensor
+            if(k + batch_dims < r)
+            {
+                auto data_lens = data_shape.lens();
+                std::copy(data_lens.begin() + batch_dims + k,
+                          data_lens.end(),
+                          output_lens.begin() + q - 1);
+            }
+            shape output_shape{data_shape.type(), output_lens};
+            return output_shape;
+        }
+        else
+        {
+            // If one or both inputs are dynamic shapes, the output is dynamic.
+            // Make both inputs dynamic to simplify computations.
+            data_shape = data_shape.to_dynamic();
+            i_shape    = i_shape.to_dynamic();
+            // A rank 0 output is a scalar
+            if(output_ndim == 0)
+                return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
+            // Part of the output shape comes from indices tensor, part from data tensor
+            std::vector<shape::dynamic_dimension> output_dims(output_ndim);
+            std::copy(i_shape.dyn_dims().begin(),
+                      i_shape.dyn_dims().begin() + q - 1,
+                      output_dims.begin());
+            // fill the rest of output shape from data tensor
+            if(k + batch_dims < r)
+            {
+                auto data_dims = data_shape.dyn_dims();
+                std::copy(data_dims.begin() + batch_dims + k,
+                          data_dims.begin() + r,
+                          output_dims.begin() + q - 1);
+            }
+            shape output_shape(data_shape.type(), output_dims);
+            return output_shape;
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         visit_all(result, args[0])([&](auto output, auto data) {
             args[1].visit([&](auto indices) {
                 auto indices_shape = indices.get_shape();
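The output rank of gathernd is q + r - k - batch_dims - 1, where r is the data rank, q the indices rank, and k the last indices dimension; because the output rank depends on k, k must stay fixed even when the indices shape is otherwise dynamic. A worked example of the rank arithmetic:

    #include <cassert>

    int main()
    {
        // data: rank r = 4, indices: rank q = 3 with last dim k = 2, batch_dims = 1.
        int r = 4, q = 3, k = 2, batch_dims = 1;
        int output_ndim = q + r - k - batch_dims - 1; // = 3
        assert(output_ndim == 3);
        // k may not exceed r - batch_dims, otherwise indices would address past the data rank.
        assert(k <= r - batch_dims);
        return 0;
    }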
src/include/migraphx/op/nonmaxsuppression.hpp (view file @ 9bff4331)

@@ -143,16 +143,22 @@ struct nonmaxsuppression
         void sort()
         {
-            std::sort(x.begin(), x.end());
-            std::sort(y.begin(), y.end());
+            if(x[0] > x[1])
+            {
+                std::swap(x[0], x[1]);
+            }
+            if(y[0] > y[1])
+            {
+                std::swap(y[0], y[1]);
+            }
         }

         std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }

         double area() const
         {
-            assert(std::is_sorted(x.begin(), x.end()));
-            assert(std::is_sorted(y.begin(), y.end()));
+            assert(x[0] <= x[1]);
+            assert(y[0] <= y[1]);
             return (x[1] - x[0]) * (y[1] - y[0]);
         }
     };

@@ -190,14 +196,10 @@ struct nonmaxsuppression
         {
             intersection[i][0] = std::max(b1[i][0], b2[i][0]);
             intersection[i][1] = std::min(b1[i][1], b2[i][1]);
-            if(intersection[i][0] > intersection[i][1])
-            {
-                return false;
-            }
         }
+        std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
+        if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
+               return not std::is_sorted(bx.begin(), bx.end());
+           }))
+        {
+            return false;
+        }

         const double area1 = b1.area();

@@ -265,31 +267,31 @@ struct nonmaxsuppression
             auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
             auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);

             selected_boxes_inside_class.clear();
+            // Get the next box with top score, filter by iou_threshold
             while(not boxes_heap.empty() &&
                   selected_boxes_inside_class.size() < max_output_boxes_per_class)
             {
-                // Check with existing selected boxes for this class, remove box if it
-                // exceeds the IOU (Intersection Over Union) threshold
+                // select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
+                // threshold with the selected box
                 const auto next_top_score = boxes_heap.top();
-                bool not_selected =
-                    std::any_of(selected_boxes_inside_class.begin(),
-                                selected_boxes_inside_class.end(),
-                                [&](auto selected_index) {
-                                    return this->suppress_by_iou(
-                                        batch_box(batch_boxes_start, next_top_score.second),
-                                        batch_box(batch_boxes_start, selected_index.second),
-                                        iou_threshold);
-                                });
-                if(not not_selected)
-                {
-                    selected_boxes_inside_class.push_back(next_top_score);
-                    selected_indices.push_back(batch_idx);
-                    selected_indices.push_back(class_idx);
-                    selected_indices.push_back(next_top_score.second);
-                }
-                boxes_heap.pop();
+                boxes_heap.pop();
+                selected_boxes_inside_class.push_back(next_top_score);
+                selected_indices.push_back(batch_idx);
+                selected_indices.push_back(class_idx);
+                selected_indices.push_back(next_top_score.second);
+
+                std::priority_queue<std::pair<double, int64_t>> remainder_boxes;
+                while(not boxes_heap.empty())
+                {
+                    auto iou_candidate_box = boxes_heap.top();
+                    if(not this->suppress_by_iou(
+                           batch_box(batch_boxes_start, iou_candidate_box.second),
+                           batch_box(batch_boxes_start, next_top_score.second),
+                           iou_threshold))
+                    {
+                        remainder_boxes.push(iou_candidate_box);
+                    }
+                    boxes_heap.pop();
+                }
+                boxes_heap = remainder_boxes;
             }
         });

         std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
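The rewritten selection loop pops the top-scoring box, keeps it, and then rebuilds the heap with only the boxes whose IoU against the kept box stays below the threshold, instead of testing each candidate against everything selected so far. A self-contained sketch of that select-then-filter pattern over a priority queue of (score, index) pairs; the overlap test here is a placeholder standing in for suppress_by_iou:

    #include <cstddef>
    #include <cstdint>
    #include <cstdlib>
    #include <queue>
    #include <utility>
    #include <vector>

    // Placeholder for the real IoU test; here two boxes "overlap" if their indices are adjacent.
    bool too_much_overlap(std::int64_t a, std::int64_t b) { return std::abs(a - b) <= 1; }

    std::vector<std::int64_t>
    select_boxes(std::priority_queue<std::pair<double, std::int64_t>> boxes_heap, std::size_t max_boxes)
    {
        std::vector<std::int64_t> selected;
        while(not boxes_heap.empty() and selected.size() < max_boxes)
        {
            auto top = boxes_heap.top(); // highest score first
            boxes_heap.pop();
            selected.push_back(top.second);

            // Keep only the remaining boxes that do not overlap too much with the one just picked.
            std::priority_queue<std::pair<double, std::int64_t>> remainder;
            while(not boxes_heap.empty())
            {
                auto candidate = boxes_heap.top();
                boxes_heap.pop();
                if(not too_much_overlap(candidate.second, top.second))
                    remainder.push(candidate);
            }
            boxes_heap = std::move(remainder);
        }
        return selected;
    }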
src/include/migraphx/op/normalize_attribute.hpp (view file @ 9bff4331)

@@ -31,18 +31,30 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

-// different attributes
-// 1) use_input(default)/use_output
-// 2) use_rank(default)/use_len
-// 3) clip_min(default)/not_clip_min
-// 3.1) include_min(default)/exclude_min
-// 4) clip_max(default)/not_clip_max
-// 4.1) exclude_max(default)/include_max
-// 5) normalize padding
+/**
+ * `normalize_attribute` settings:
+ * Note that default options are not included as enums.
+ * 1. `use_input` (default) vs. `use_output`:
+ *  Affects the rank of the attribute.
+ *  `use_input -> lens.size()`, `use_output -> lens.size() + vec.size()`.
+ * 2. use_rank (default) vs use_len:
+ *  `use_rank` sets the max value/index of the attribute as the rank of lens.
+ *  `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
+ * 3. `clip_min` vs. `not_clip_min` (default):
+ *  Clip values less than the minimum to the minimum or not.
+ * 4. `include_min` vs. `exclude_min` (default):
+ *  Include or exclude the minimum value/index for range checking and clipping.
+ * 5. `clip_max` vs. `not_clip_max` (default):
+ *  Clip values greater than the maximum or not.
+ * 6. `include_max` vs. `exclude_max` (default):
+ *  Include or exclude the maximum value/index for range checking and clipping.
+ * 7. `normalize_padding`:
+ *  To normalize the padding to `2*(pad ndim)` dimensions.
+ */
 enum class normalize_attribute
 {
-    use_len,
     use_output,
+    use_len,
     clip_max,
     clip_min,
     include_max,
src/include/migraphx/op/pad.hpp (view file @ 9bff4331)

@@ -59,18 +59,29 @@ struct pad
     std::string name() const { return "pad"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        auto&& idims = inputs.front().lens();
-        std::vector<std::size_t> rdims(idims.begin(), idims.end());
-        std::size_t num_dims = rdims.size();
-        for(std::size_t i = 0; i < num_dims; i++)
-        {
-            rdims[i] += pads[i] + pads[i + num_dims];
-        }
-        shape s{inputs.front().type(), rdims};
-        return s;
+        check_shapes{inputs, *this, true}.has(1);
+        const auto& s0 = inputs.front();
+        if(s0.dynamic())
+        {
+            auto out_dyn_dims = s0.dyn_dims();
+            for(std::size_t i = 0; i < s0.ndim(); ++i)
+            {
+                out_dyn_dims[i] += pads[i] + pads[i + s0.ndim()];
+            }
+            return {s0.type(), out_dyn_dims};
+        }
+        else
+        {
+            auto&& idims = s0.lens();
+            std::vector<std::size_t> rdims(idims.begin(), idims.end());
+            std::size_t num_dims = rdims.size();
+            for(std::size_t i = 0; i < num_dims; i++)
+            {
+                rdims[i] += pads[i] + pads[i + num_dims];
+            }
+            shape s{s0.type(), rdims};
+            return s;
+        }
     }

     std::size_t pad_ndims() const
src/include/migraphx/op/reduce_op.hpp (view file @ 9bff4331)

@@ -26,6 +26,7 @@
 #include <migraphx/op/name.hpp>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/tensor_view.hpp>
 #include <migraphx/shape_for_each.hpp>

@@ -105,18 +106,41 @@ struct reduce_op : op_name<Derived>
         return tuned_axes;
     }

+    /**
+     * @brief returns a shape in which the axis or axes named
+     * for reduction by this op are set, to size 1.
+     *
+     * @param inputs list of input shapes
+     * @return shape
+     */
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
         auto s = inputs.at(0);
-        auto lens       = s.lens();
-        auto tuned_axes = tune_axes(lens.size());
-        for(auto axis : tuned_axes)
-        {
-            lens[axis] = 1;
-        }
-        return inputs[0].with_lens(lens);
+        if(s.dynamic())
+        {
+            auto output_dyn_dims = s.dyn_dims();
+            auto tuned_axes      = tune_axes(output_dyn_dims.size());
+            for(const auto& axis : tuned_axes)
+            {
+                // At the time of writing, there's no functional difference between
+                // optimum of 0 (no opt) or 1.
+                output_dyn_dims[axis] = {1, 1, 0};
+            }
+            return shape{s.type(), output_dyn_dims};
+        }
+        else
+        {
+            auto lens       = s.lens();
+            auto tuned_axes = tune_axes(lens.size());
+            for(const auto& axis : tuned_axes)
+            {
+                lens[axis] = 1;
+            }
+            return inputs[0].with_lens(lens);
+        }
     }

     template <class T>

@@ -124,7 +148,7 @@ struct reduce_op : op_name<Derived>
                    const std::vector<T>& in_lens,
                    std::vector<T>& out_lens) const
     {
-        for(auto axis : tuned_axes)
+        for(const auto& axis : tuned_axes)
         {
             out_lens[axis] = in_lens[axis];
         }

@@ -151,17 +175,17 @@ struct reduce_op : op_name<Derived>
         static_cast<const Derived&>(*this).output(batch_shape)(val);
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         auto arg_lens   = args.front().get_shape().lens();
         auto tuned_axes = tune_axes(arg_lens.size());
-        std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
+        std::vector<std::size_t> batch_lens(dyn_out.computed_shape.lens().size(), 1);
         tune_dims(tuned_axes, arg_lens, batch_lens);
-        shape batch_shape{output_shape.type(), batch_lens};
+        shape batch_shape{dyn_out.computed_shape.type(), batch_lens};
         visit_all(result, args[0])([&](auto output, auto input) {
-            par_for(output_shape.elements(), [&](auto i) {
-                auto out_idx = output_shape.multi(i);
+            par_for(dyn_out.computed_shape.elements(), [&](auto i) {
+                auto out_idx = dyn_out.computed_shape.multi(i);
                 this->reduce(input, batch_shape, tuned_axes, out_idx, output);
             });
         });
src/include/migraphx/op/reshape.hpp (view file @ 9bff4331)

@@ -28,6 +28,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -46,14 +47,60 @@ struct reshape
     value attributes() const { return {{"require_std_shape", true}}; }

     std::string name() const { return "reshape"; }

-    shape compute_shape(std::vector<shape> inputs) const
+    shape dyn_compute_shape(shape s0) const
+    {
+        auto dyn_dims      = s0.dyn_dims();
+        auto num_not_fixed = std::count_if(
+            dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
+        if(num_not_fixed != 1)
+        {
+            MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension");
+        }
+        // track number of fixed elements in input and output
+        std::size_t num_dims_ele = 1;
+        std::size_t num_dd_ele   = 1;
+        for(std::size_t i = 0; i < dyn_dims.size(); ++i)
+        {
+            if(dyn_dims[i].is_fixed())
+            {
+                num_dims_ele *= dims[i];
+                num_dd_ele *= dyn_dims[i].min;
+            }
+            else
+            {
+                if(dims[i] != 0 and dims[i] != -1)
+                {
+                    MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
+                                   "output dimension");
+                }
+            }
+        }
+        if(num_dims_ele != num_dd_ele)
+        {
+            MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
+                           std::to_string(num_dd_ele) +
+                           " Output: " + std::to_string(num_dims_ele));
+        }
+        // construct output dynamic shape from dims attribute
+        std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
+        std::transform(dims.cbegin(),
+                       dims.cend(),
+                       dyn_dims.cbegin(),
+                       output_dyn_dims.begin(),
+                       [](std::size_t dim, auto dyn_dim) {
+                           if(not dyn_dim.is_fixed())
+                               return dyn_dim;
+                           return shape::dynamic_dimension{dim, dim};
+                       });
+        return {s0.type(), output_dyn_dims};
+    }
+
+    shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
+        check_shapes{inputs, *this}.standard();
         auto&& idims = inputs.front().lens();
         std::vector<std::size_t> rdims(dims.begin(), dims.end());
-        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
-        if(n_neg_dims > 1)
-            MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
         for(std::size_t i = 0; i < dims.size(); i++)
         {

@@ -86,9 +133,26 @@ struct reshape
         return s;
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this, true}.has(1);
+        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
+        if(n_neg_dims > 1)
+            MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
+        auto s0 = inputs[0];
+        if(s0.dynamic())
+        {
+            return dyn_compute_shape(s0);
+        }
+        else
+        {
+            return static_compute_shape(inputs, n_neg_dims);
+        }
+    }
+
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
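For static shapes, a -1 entry in dims is inferred so that the element count is preserved, and the new compute_shape rejects more than one -1 before dispatching to the static or dynamic path. A minimal standard-library sketch of that inference (illustrative only; the real operator also handles 0-valued dims and dynamic_dimension inputs):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // Replace a single -1 in dims so that the product of dims equals total_elements.
    std::vector<std::size_t> infer_reshape_dims(std::vector<std::int64_t> dims, std::size_t total_elements)
    {
        auto n_neg = std::count(dims.begin(), dims.end(), -1);
        if(n_neg > 1)
            throw std::runtime_error("reshape: only one -1 dim allowed");

        std::size_t known = 1;
        for(auto d : dims)
            if(d > 0)
                known *= static_cast<std::size_t>(d);

        std::vector<std::size_t> out;
        for(auto d : dims)
            out.push_back(d == -1 ? total_elements / known : static_cast<std::size_t>(d));
        return out; // e.g. dims {2, -1} with 24 elements -> {2, 12}
    }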