Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8ba8f907
Unverified
Commit
8ba8f907
authored
Apr 05, 2019
by
mvermeulen
Committed by
GitHub
Apr 05, 2019
Browse files
Merge pull request #224 from ROCmSoftwarePlatform/separate_op_headfiles
Separate headfile operators.hpp into multiple files.
parents
a065486a
a7ee70a9
Changes
100
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
563 additions
and
8 deletions
+563
-8
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+1
-1
src/eliminate_allocation.cpp
src/eliminate_allocation.cpp
+1
-1
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+2
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+0
-1
src/eliminate_identity.cpp
src/eliminate_identity.cpp
+0
-1
src/eliminate_pad.cpp
src/eliminate_pad.cpp
+4
-1
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+3
-1
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+1
-1
src/include/migraphx/op/abnormal_ops.hpp
src/include/migraphx/op/abnormal_ops.hpp
+62
-0
src/include/migraphx/op/abs.hpp
src/include/migraphx/op/abs.hpp
+29
-0
src/include/migraphx/op/acos.hpp
src/include/migraphx/op/acos.hpp
+29
-0
src/include/migraphx/op/add.hpp
src/include/migraphx/op/add.hpp
+29
-0
src/include/migraphx/op/as_shape.hpp
src/include/migraphx/op/as_shape.hpp
+46
-0
src/include/migraphx/op/asin.hpp
src/include/migraphx/op/asin.hpp
+29
-0
src/include/migraphx/op/atan.hpp
src/include/migraphx/op/atan.hpp
+29
-0
src/include/migraphx/op/batch_norm.hpp
src/include/migraphx/op/batch_norm.hpp
+56
-0
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+34
-0
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+75
-0
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+38
-0
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+95
-0
No files found.
src/auto_contiguous.cpp
View file @
8ba8f907
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erator
s.hpp>
#include <migraphx/op
/contiguou
s.hpp>
#include <migraphx/iterator_for.hpp>
namespace
migraphx
{
...
...
src/eliminate_allocation.cpp
View file @
8ba8f907
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/load
.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
...
...
src/eliminate_concat.cpp
View file @
8ba8f907
...
...
@@ -2,7 +2,8 @@
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
...
...
src/eliminate_contiguous.cpp
View file @
8ba8f907
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
...
...
src/eliminate_identity.cpp
View file @
8ba8f907
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
...
...
src/eliminate_pad.cpp
View file @
8ba8f907
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
...
...
src/fwd_conv_batchnorm_rewrite.cpp
View file @
8ba8f907
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
...
...
src/include/migraphx/concat_opt.hpp
View file @
8ba8f907
...
...
@@ -9,7 +9,7 @@
#include <utility>
#include <migraphx/operation.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/concat
.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
...
...
src/include/migraphx/op/abnormal_ops.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
not_computable
{
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPHX_THROW
(
"not computable"
);
}
};
struct
undefined
{
std
::
string
name
()
const
{
return
"undefined"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
{};
}
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
return
{{},
nullptr
};
}
};
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/abs.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_ABS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ABS_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
abs
:
unary
{
std
::
string
name
()
const
{
return
"abs"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/acos.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOS_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
acos
:
unary
{
std
::
string
name
()
const
{
return
"acos"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/add.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_ADD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
add
:
binary
{
std
::
string
name
()
const
{
return
"add"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/as_shape.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
as_shape
{
shape
s
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
s
,
"shape"
));
}
std
::
string
name
()
const
{
return
"as_shape"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
assert
(
inputs
.
front
().
elements
()
==
s
.
elements
());
return
s
;
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/asin.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_ASIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASIN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
asin
:
unary
{
std
::
string
name
()
const
{
return
"asin"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/atan.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_ATAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATAN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
atan
:
unary
{
std
::
string
name
()
const
{
return
"atan"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/batch_norm.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
batch_norm_inference
{
float
epsilon
=
1.0e-6
f
;
float
momentum
=
0.9
f
;
std
::
string
name
()
const
{
return
"batch_norm_inference"
;
}
enum
bn_infer_mode_t
{
per_activation
,
spatial
,
};
bn_infer_mode_t
bn_mode
=
spatial
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
epsilon
,
"epsilon"
),
f
(
self
.
momentum
,
"momentum"
),
f
(
self
.
bn_mode
,
"bn_mode"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
.
data
(),
inputs
.
data
()
+
1
,
*
this
}.
only_dims
(
4
);
check_shapes
{
inputs
.
data
()
+
1
,
inputs
.
data
()
+
inputs
.
size
(),
*
this
}.
same_shape
().
elements
(
inputs
.
front
().
lens
()[
1
]);
return
inputs
.
front
();
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/binary.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
binary
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
lens
=
inputs
.
at
(
0
).
lens
();
return
{
t
,
lens
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/broadcast.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP
#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
struct
broadcast
{
uint64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
shape
broadcast_shape
;
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_shape
.
lens
().
size
(),
0
);
if
(
std
::
all_of
(
broadcast_shape
.
lens
().
cbegin
(),
broadcast_shape
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
{
if
(
axis
!=
0
)
MIGRAPHX_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
}
else
{
assert
(
broadcast_shape
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_shape
.
lens
().
begin
()
+
axis
))
MIGRAPHX_THROW
(
"when broadcasting success sizes must match"
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/common.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
enum
padding_mode_t
{
default_
,
// NOLINT
same
,
valid
};
// indicate rnn computation direction
enum
class
rnn_direction
{
forward
,
reverse
,
bidirectional
,
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/concat.hpp
0 → 100644
View file @
8ba8f907
#ifndef MIGRAPHX_GUARD_OPERATORS_CONCAT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONCAT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
concat
{
std
::
size_t
axis
=
0
;
std
::
string
name
()
const
{
return
"concat"
;
}
std
::
vector
<
std
::
size_t
>
compute_offsets
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
std
::
vector
<
std
::
size_t
>
offsets
;
std
::
vector
<
std
::
size_t
>
offset
(
args
[
0
].
get_shape
().
lens
().
size
(),
0
);
offset
[
axis
]
=
0
;
for
(
const
auto
&
arg
:
args
)
{
offsets
.
push_back
(
output_shape
.
index
(
offset
));
offset
[
axis
]
+=
arg
.
get_shape
().
lens
()[
axis
];
}
return
offsets
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
{
MIGRAPHX_THROW
(
"Number of input tensors should exceed 0"
);
}
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
type
=
inputs
.
front
().
type
();
for
(
std
::
size_t
l
=
0
;
l
<
first_shape_lens
.
size
();
l
++
)
{
if
(
l
!=
axis
)
{
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
{
MIGRAPHX_THROW
(
"Non-axis dimensions should match"
);
}
}
}
std
::
size_t
new_dim_axis
=
0
;
for
(
const
auto
&
input
:
inputs
)
{
const
auto
&
lens
=
input
.
lens
();
new_dim_axis
+=
lens
[
axis
];
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
copy
(
first_shape_lens
.
begin
(),
first_shape_lens
.
end
(),
std
::
back_inserter
(
new_lens
));
new_lens
[
axis
]
=
new_dim_axis
;
return
{
type
,
new_lens
};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
output_shape
,
args
);
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
auto
argl
=
args
[
l
];
std
::
size_t
nelements
=
argl
.
get_shape
().
elements
();
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
auto
slice_shape
=
shape
{
output_shape
.
type
(),
input
.
get_shape
().
lens
(),
output_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
// cppcheck-suppress useStlAlgorithm
for
(
std
::
size_t
i
=
0
;
i
<
nelements
;
i
++
)
{
slice
[
i
]
=
input
[
i
];
}
});
}
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment