Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7702c20d
"docs/vscode:/vscode.git/clone" did not exist on "e5464ee484450c2671dd0226516c99c60ce70d9d"
Commit
7702c20d
authored
Aug 19, 2022
by
Paul
Browse files
Merge
parents
c362e7fa
9afce86d
Changes
248
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
160 additions
and
103 deletions
+160
-103
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+1
-5
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+1
-7
src/include/migraphx/op/step.hpp
src/include/migraphx/op/step.hpp
+1
-6
src/include/migraphx/op/sub.hpp
src/include/migraphx/op/sub.hpp
+1
-8
src/include/migraphx/op/tan.hpp
src/include/migraphx/op/tan.hpp
+1
-8
src/include/migraphx/op/tanh.hpp
src/include/migraphx/op/tanh.hpp
+1
-8
src/include/migraphx/op/topk.hpp
src/include/migraphx/op/topk.hpp
+1
-0
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+2
-5
src/include/migraphx/op/unary_not.hpp
src/include/migraphx/op/unary_not.hpp
+2
-3
src/include/migraphx/op/unknown.hpp
src/include/migraphx/op/unknown.hpp
+0
-1
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+24
-14
src/include/migraphx/op/where.hpp
src/include/migraphx/op/where.hpp
+2
-9
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+5
-3
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-0
src/include/migraphx/pad_calc.hpp
src/include/migraphx/pad_calc.hpp
+23
-25
src/include/migraphx/par_dfor.hpp
src/include/migraphx/par_dfor.hpp
+1
-0
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+6
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+5
-0
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+6
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+75
-1
No files found.
src/include/migraphx/op/slice.hpp
View file @
7702c20d
...
@@ -25,14 +25,10 @@
...
@@ -25,14 +25,10 @@
#define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <vector>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/squeeze.hpp
View file @
7702c20d
...
@@ -24,17 +24,11 @@
...
@@ -24,17 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/step.hpp
View file @
7702c20d
...
@@ -24,16 +24,11 @@
...
@@ -24,16 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#define MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#define MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#include "migraphx/stringutils.hpp"
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
lifetim
e.hpp>
#include <migraphx/
valu
e.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/sub.hpp
View file @
7702c20d
...
@@ -24,16 +24,9 @@
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/binary.hpp>
#include <cmath>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/tan.hpp
View file @
7702c20d
...
@@ -24,16 +24,9 @@
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/tanh.hpp
View file @
7702c20d
...
@@ -24,16 +24,9 @@
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/topk.hpp
View file @
7702c20d
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <algorithm>
#include <algorithm>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
...
...
src/include/migraphx/op/transpose.hpp
View file @
7702c20d
...
@@ -24,14 +24,11 @@
...
@@ -24,14 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <migraphx/op/normalize_attribute.hpp>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/unary_not.hpp
View file @
7702c20d
...
@@ -24,10 +24,9 @@
...
@@ -24,10 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/unknown.hpp
View file @
7702c20d
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
#define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
7702c20d
...
@@ -24,16 +24,11 @@
...
@@ -24,16 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -42,11 +37,12 @@ namespace op {
...
@@ -42,11 +37,12 @@ namespace op {
struct
unsqueeze
struct
unsqueeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
steps
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
return
pack
(
f
(
self
.
axes
,
"axes"
)
,
f
(
self
.
steps
,
"steps"
)
);
}
}
value
attributes
()
const
value
attributes
()
const
...
@@ -73,6 +69,9 @@ struct unsqueeze
...
@@ -73,6 +69,9 @@ struct unsqueeze
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
}
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
...
@@ -80,16 +79,27 @@ struct unsqueeze
...
@@ -80,16 +79,27 @@ struct unsqueeze
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
for
(
auto
i
:
range
(
new_size
))
{
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
{
{
new_lens
[
i
]
=
1
;
std
::
int64_t
step
=
1
;
if
(
p
==
0
)
// unsqueeze on the first axes
if
(
axis_idx
<
steps
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
{
new_strides
[
i
]
=
old_lens
[
0
]
*
old_strides
[
0
];
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
}
}
else
// unsqueeze on middle or last axes
else
{
{
new_strides
[
i
]
=
(
p
<
old_strides
.
size
())
?
old_strides
[
p
-
1
]
:
1
;
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
}
}
}
}
else
else
...
...
src/include/migraphx/op/where.hpp
View file @
7702c20d
...
@@ -24,18 +24,11 @@
...
@@ -24,18 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/operation.hpp
View file @
7702c20d
...
@@ -68,8 +68,10 @@ struct operation
...
@@ -68,8 +68,10 @@ struct operation
*
*
* @param ctx This is the context created by the `target` during compilation. Implementations
* @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class.
* can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
* `shape` of the `argument`.
* For a fixed shape, the returned argument will have the same shape as `output`.
* For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
* set in the dynamic shape `output`.
* @param input This is the `argument` result from the previous instruction's computation.
* @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
* the same the `output` shape.
...
@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
...
@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
{
{
dependent_type
<
operation
,
T
>
y
=
x
;
dependent_type
<
operation
,
T
>
y
=
x
;
normalize_attributes
(
y
,
inputs
[
0
].
lens
());
normalize_attributes
(
y
,
inputs
[
0
].
max_
lens
());
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
}
}
...
...
src/include/migraphx/operators.hpp
View file @
7702c20d
...
@@ -57,6 +57,7 @@
...
@@ -57,6 +57,7 @@
#include <migraphx/op/exp.hpp>
#include <migraphx/op/exp.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/fmod.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
...
@@ -79,6 +80,7 @@
...
@@ -79,6 +80,7 @@
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
#include <migraphx/op/max.hpp>
#include <migraphx/op/min.hpp>
#include <migraphx/op/min.hpp>
#include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/neg.hpp>
...
...
src/include/migraphx/pad_calc.hpp
View file @
7702c20d
...
@@ -24,38 +24,36 @@
...
@@ -24,38 +24,36 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <
utility
>
#include <
migraphx/config.hpp
>
#include <cstdint>
#include <cstdint>
#include <vector>
#include <vector>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
void
calculate_padding
(
int64_t
idx
,
void
calculate_padding
(
int64_t
idx
,
std
::
vector
<
int64_t
>&
pads
,
std
::
vector
<
int64_t
>&
pads
,
int64_t
input_dim
,
int64_t
input_dim
,
int64_t
stride
,
int64_t
stride
,
int64_t
dilation
,
int64_t
dilation
,
int64_t
weight_dim
,
int64_t
weight_dim
,
bool
is_same_upper
=
true
)
bool
is_same_upper
=
true
);
{
int64_t
output_dim
=
(
input_dim
+
stride
-
1
)
/
stride
;
// round up result
int64_t
new_weight_dim
=
weight_dim
+
(
weight_dim
-
1
)
*
(
dilation
-
1
);
int64_t
pad
=
std
::
max
(
static_cast
<
int64_t
>
(
0
),
(
output_dim
-
1
)
*
stride
+
new_weight_dim
-
input_dim
);
auto
pad_ndims
=
pads
.
size
()
/
2
;
if
(
is_same_upper
)
/*!
{
* Calculate the padding for auto_padding. Used for dynamic shapes
pads
[
idx
]
=
pad
/
2
;
* where the padding calculation must be done at evaluation time.
pads
[
idx
+
pad_ndims
]
=
pad
-
pad
/
2
;
* \param tensor_lens input tensor image shape
}
* \param k_lens weights kernel shape
else
* \param strides strides for the kernel
{
* \param dilations dilations for the kernel
pads
[
idx
+
pad_ndims
]
=
pad
/
2
;
* \param use_upper put odd padding on upper or lower side
pads
[
idx
]
=
pad
-
pad
/
2
;
* \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
}
*/
}
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
std
::
vector
<
std
::
size_t
>
tensor_lens
,
std
::
vector
<
std
::
size_t
>
k_lens
,
std
::
vector
<
std
::
size_t
>
strides
,
std
::
vector
<
std
::
size_t
>
dilations
,
bool
use_upper
=
true
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/par_dfor.hpp
View file @
7702c20d
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <array>
#include <numeric>
#include <numeric>
...
...
src/include/migraphx/permutation.hpp
View file @
7702c20d
...
@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
...
@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return
result
;
return
result
;
}
}
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
*/
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
*/
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
...
...
src/include/migraphx/program.hpp
View file @
7702c20d
...
@@ -33,6 +33,8 @@
...
@@ -33,6 +33,8 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <algorithm>
...
@@ -84,6 +86,9 @@ struct program
...
@@ -84,6 +86,9 @@ struct program
instruction_ref
validate
()
const
;
instruction_ref
validate
()
const
;
target_assignments
get_target_assignments
(
const
std
::
vector
<
target
>&
targets
,
assignment_options
options
=
assignment_options
{});
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
bool
is_compiled
()
const
;
bool
is_compiled
()
const
;
...
...
src/include/migraphx/ranges.hpp
View file @
7702c20d
...
@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
...
@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
std
::
transform
(
r
.
begin
(),
r
.
end
(),
it
,
f
);
std
::
transform
(
r
.
begin
(),
r
.
end
(),
it
,
f
);
}
}
template
<
class
Range1
,
class
Range2
,
class
Iterator
,
class
F
>
void
transform
(
Range1
&&
r1
,
Range2
&&
r2
,
Iterator
it
,
F
f
)
{
std
::
transform
(
r1
.
begin
(),
r1
.
end
(),
r2
.
begin
(),
it
,
f
);
}
template
<
class
Range
>
template
<
class
Range
>
auto
reverse
(
Range
&
r
)
auto
reverse
(
Range
&
r
)
{
{
...
...
src/include/migraphx/shape.hpp
View file @
7702c20d
...
@@ -82,6 +82,23 @@ struct shape
...
@@ -82,6 +82,23 @@ struct shape
{
{
};
};
struct
dynamic_dimension
{
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
opt
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
);
bool
is_fixed
()
const
;
bool
has_optimal
()
const
;
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
};
static
const
std
::
vector
<
type_t
>&
types
();
static
const
std
::
vector
<
type_t
>&
types
();
static
std
::
string
name
(
type_t
t
);
static
std
::
string
name
(
type_t
t
);
...
@@ -92,6 +109,12 @@ struct shape
...
@@ -92,6 +109,12 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape
(
type_t
t
,
std
::
initializer_list
<
std
::
size_t
>
d
);
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
template
<
class
Range
>
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
{
...
@@ -112,10 +135,44 @@ struct shape
...
@@ -112,10 +135,44 @@ struct shape
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
/*!
* Return the number of elements in the tensor.
*/
std
::
size_t
elements
()
const
;
std
::
size_t
elements
()
const
;
/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/
std
::
size_t
bytes
()
const
;
std
::
size_t
bytes
()
const
;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std
::
size_t
type_size
()
const
;
std
::
size_t
type_size
()
const
;
const
std
::
vector
<
dynamic_dimension
>&
dyn_dims
()
const
;
/*!
* Minimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
min_lens
()
const
;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
max_lens
()
const
;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
;
/// Map multiple indices to space index
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
/// Map multiple indices to space index
/// Map multiple indices to space index
...
@@ -136,19 +193,27 @@ struct shape
...
@@ -136,19 +193,27 @@ struct shape
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed with no padding
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
bool
packed
()
const
;
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
/// order
bool
transposed
()
const
;
bool
transposed
()
const
;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool
broadcasted
()
const
;
bool
broadcasted
()
const
;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
/// not transposed.
bool
standard
()
const
;
bool
standard
()
const
;
/// Returns true if all strides are equal to 0 (scalar tensor)
/// Returns true if all strides are equal to 0 (scalar tensor)
bool
scalar
()
const
;
bool
scalar
()
const
;
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
@@ -191,6 +256,10 @@ struct shape
...
@@ -191,6 +256,10 @@ struct shape
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
return
sizeof
(
type
)
*
n
;
}
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
return
sizeof
(
type
)
*
n
;
}
auto
is_integral
()
const
{
return
std
::
is_integral
<
type
>
{};
}
auto
is_signed
()
const
{
return
std
::
is_signed
<
type
>
{};
}
auto
is_unsigned
()
const
{
return
std
::
is_unsigned
<
type
>
{};
}
template
<
class
U
>
template
<
class
U
>
type
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
type
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
{
...
@@ -248,6 +317,11 @@ struct shape
...
@@ -248,6 +317,11 @@ struct shape
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
/*!
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
private:
private:
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment