Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3a848f0d
Commit
3a848f0d
authored
Mar 19, 2020
by
Paul
Browse files
Merge branch 'develop' into doc2
parents
64e8e30a
d1e945da
Changes
208
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
319 additions
and
55 deletions
+319
-55
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+1
-1
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+13
-1
src/include/migraphx/op/acosh.hpp
src/include/migraphx/op/acosh.hpp
+22
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+14
-7
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+13
-7
src/include/migraphx/op/asinh.hpp
src/include/migraphx/op/asinh.hpp
+22
-0
src/include/migraphx/op/atanh.hpp
src/include/migraphx/op/atanh.hpp
+22
-0
src/include/migraphx/op/deconvolution.hpp
src/include/migraphx/op/deconvolution.hpp
+68
-0
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+12
-9
src/include/migraphx/op/logsoftmax.hpp
src/include/migraphx/op/logsoftmax.hpp
+3
-2
src/include/migraphx/op/prelu.hpp
src/include/migraphx/op/prelu.hpp
+22
-0
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+9
-0
src/include/migraphx/op/reduce_prod.hpp
src/include/migraphx/op/reduce_prod.hpp
+27
-0
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+3
-2
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+12
-4
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+14
-5
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+8
-1
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+13
-7
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+8
-2
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+13
-7
No files found.
src/include/migraphx/instruction.hpp
View file @
3a848f0d
...
...
@@ -93,7 +93,7 @@ struct instruction
void
replace
(
const
shape
&
r
);
operation
op
;
shape
result
;
shape
result
{}
;
std
::
vector
<
instruction_ref
>
output
;
std
::
vector
<
instruction_ref
>
arguments
;
literal
lit
;
...
...
src/include/migraphx/onnx.hpp
View file @
3a848f0d
...
...
@@ -7,8 +7,20 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// struct to pass in onnx options to parser
struct
onnx_options
{
unsigned
int
batch_size
=
1
;
};
/// Create a program from an onnx file
program
parse_onnx
(
const
std
::
string
&
name
);
program
parse_onnx
(
const
std
::
string
&
name
,
onnx_options
=
onnx_options
{});
/// Create a program from an onnx buffer
program
parse_onnx_buffer
(
const
std
::
string
&
buffer
,
onnx_options
options
);
/// Create a program from an onnx buffer
program
parse_onnx_buffer
(
const
void
*
data
,
std
::
size_t
size
,
onnx_options
options
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/op/acosh.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
acosh
:
unary
<
acosh
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
acosh
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/argmax.hpp
View file @
3a848f0d
...
...
@@ -27,25 +27,30 @@ struct argmax
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
if
(
axis
>=
n_dim
||
axis
<
-
n_dim
)
{
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
}
lens
[
axis
]
=
1
;
int64_t
tuned_axis
=
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
lens
[
tuned_axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
int64_t
calc_argmax
(
T
&
input
,
int64_t
tuned_axis
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
max_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
max_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
indices
[
tuned_
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
max_val
<
cur_val
)
{
max_val
=
cur_val
;
...
...
@@ -59,13 +64,15 @@ struct argmax
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
auto
tuned_axis
=
axis
<
0
?
axis
+
n_dim
:
axis
;
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
tuned_axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
tuned_axis
,
data_idx
,
batch_item_num
);
});
});
});
...
...
src/include/migraphx/op/argmin.hpp
View file @
3a848f0d
...
...
@@ -27,25 +27,29 @@ struct argmin
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
if
(
axis
>=
n_dim
||
axis
<
-
n_dim
)
{
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
}
lens
[
axis
]
=
1
;
int64_t
tuned_axis
=
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
lens
[
tuned_axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmin
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
int64_t
calc_argmin
(
T
&
input
,
int64_t
tuned_axis
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
min_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
min_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
indices
[
tuned_
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
min_val
>
cur_val
)
{
min_val
=
cur_val
;
...
...
@@ -59,13 +63,15 @@ struct argmin
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
std
::
size_t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
auto
tuned_axis
=
axis
<
0
?
axis
+
n_dim
:
axis
;
std
::
size_t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
tuned_axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmin
(
input
,
data_idx
,
batch_item_num
);
output
[
i
]
=
this
->
calc_argmin
(
input
,
tuned_axis
,
data_idx
,
batch_item_num
);
});
});
});
...
...
src/include/migraphx/op/asinh.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
asinh
:
unary
<
asinh
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
asinh
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/atanh.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
atanh
:
unary
<
atanh
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
atanh
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/deconvolution.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
deconvolution
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{{
1
,
1
}};
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
group
,
"group"
));
}
std
::
string
name
()
const
{
return
"deconvolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
auto
t
=
input
.
type
();
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
stride
[
0
]
*
(
input
.
lens
()[
2
]
-
1
)
+
((
weights
.
lens
()[
2
]
-
1
)
*
dilation
[
0
]
+
1
)
-
2
*
padding
[
0
])),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
stride
[
1
]
*
(
input
.
lens
()[
3
]
-
1
)
+
((
weights
.
lens
()[
3
]
-
1
)
*
dilation
[
1
]
+
1
)
-
2
*
padding
[
1
])),
}};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/flatten.hpp
View file @
3a848f0d
...
...
@@ -18,7 +18,7 @@ namespace op {
struct
flatten
{
u
int64_t
axis
=
0
;
int64_t
axis
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -30,16 +30,19 @@ struct flatten
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
if
(
axis
>
lens
.
size
()
)
auto
&&
lens
=
inputs
.
front
().
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>
n_dim
or
axis
<
-
n_dim
)
{
MIGRAPHX_THROW
(
"axis for flatten
must be less than tensor
ran
k
"
);
MIGRAPHX_THROW
(
"
FLATTEN:
axis for flatten
is out of
ran
ge
"
);
}
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
tuned_axis
=
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
tuned_axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
tuned_axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/logsoftmax.hpp
View file @
3a848f0d
...
...
@@ -11,7 +11,7 @@ namespace op {
struct
logsoftmax
{
int
axis
=
1
;
int
64_t
axis
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -23,7 +23,8 @@ struct logsoftmax
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
standard
();
if
(
axis
<
0
||
axis
>=
inputs
[
0
].
lens
().
size
())
int64_t
n_dim
=
static_cast
<
int64_t
>
(
inputs
[
0
].
lens
().
size
());
if
(
axis
<
-
n_dim
||
axis
>=
n_dim
)
{
MIGRAPHX_THROW
(
"LogSoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
" is out of range"
);
...
...
src/include/migraphx/op/prelu.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_PRELU_HPP
#define MIGRAPHX_GUARD_OPERATORS_PRELU_HPP
#include <migraphx/op/binary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
prelu
:
binary
<
prelu
>
{
auto
apply
()
const
{
return
[](
auto
x
,
auto
slope
)
{
return
((
x
<
0
)
?
(
x
*
slope
)
:
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reduce_op.hpp
View file @
3a848f0d
...
...
@@ -40,6 +40,15 @@ struct zero
}
};
struct
one
{
template
<
class
T
>
operator
T
()
const
{
return
T
{
1
};
}
};
template
<
class
Derived
>
struct
reduce_op
:
op_name
<
Derived
>
{
...
...
src/include/migraphx/op/reduce_prod.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP
#define MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP
#include <migraphx/op/reduce_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
reduce_prod
:
reduce_op
<
reduce_prod
>
{
reduce_prod
()
{}
reduce_prod
(
std
::
vector
<
int64_t
>
ax
)
:
reduce_op
(
std
::
move
(
ax
))
{}
auto
op
()
const
{
return
[
=
](
auto
x
,
auto
y
)
{
return
x
*
y
;
};
}
auto
init
()
const
{
return
one
();
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/softmax.hpp
View file @
3a848f0d
...
...
@@ -11,7 +11,7 @@ namespace op {
struct
softmax
{
int
axis
=
1
;
int
64_t
axis
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -23,7 +23,8 @@ struct softmax
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
standard
();
if
(
axis
<
0
||
axis
>=
inputs
[
0
].
lens
().
size
())
int64_t
n_dim
=
inputs
[
0
].
lens
().
size
();
if
(
axis
<
-
n_dim
||
axis
>=
n_dim
)
{
MIGRAPHX_THROW
(
"SoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
" is out of range"
);
...
...
src/include/migraphx/op/squeeze.hpp
View file @
3a848f0d
...
...
@@ -33,13 +33,21 @@ struct squeeze
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
lens
()[
axis
]
!=
1
;
}))
// change to support negative axis value
std
::
vector
<
int64_t
>
tuned_axes
(
axes
.
size
());
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
tuned_axes
.
begin
(),
[
&
](
auto
i
)
{
return
i
>=
0
?
i
:
i
+
old_lens
.
size
();
});
if
(
std
::
any_of
(
tuned_axes
.
begin
(),
tuned_axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
if
(
axes
.
empty
())
if
(
tuned_
axes
.
empty
())
{
std
::
copy_if
(
old_lens
.
begin
(),
old_lens
.
end
(),
...
...
@@ -50,7 +58,7 @@ struct squeeze
{
for
(
std
::
size_t
i
=
0
;
i
<
old_lens
.
size
();
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
if
(
std
::
find
(
tuned_
axes
.
begin
(),
tuned_
axes
.
end
(),
i
)
==
tuned_
axes
.
end
())
{
new_lens
.
push_back
(
old_lens
[
i
]);
}
...
...
src/include/migraphx/op/transpose.hpp
View file @
3a848f0d
...
...
@@ -34,13 +34,22 @@ struct transpose
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input_lens
.
size
())
auto
tuned_dims
=
dims
;
// if not perm provided, reverse the dims
if
(
tuned_dims
.
empty
())
{
tuned_dims
.
resize
(
input_lens
.
size
());
std
::
iota
(
tuned_dims
.
begin
(),
tuned_dims
.
end
(),
0
);
std
::
reverse
(
tuned_dims
.
begin
(),
tuned_dims
.
end
());
}
if
(
tuned_dims
.
size
()
!=
input_lens
.
size
())
{
MIGRAPHX_THROW
(
"Permutation has wrong number of axes"
);
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
vector
<
int64_t
>
axes
(
tuned_
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
tuned_
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"Invalid permutation"
);
}
...
...
@@ -48,8 +57,8 @@ struct transpose
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
for
(
std
::
size_t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
output_lens
[
i
]
=
input_lens
[
tuned_
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
tuned_
dims
[
i
]];
}
return
{
t
,
output_lens
,
output_strides
};
}
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
3a848f0d
...
...
@@ -38,11 +38,18 @@ struct unsqueeze
return
shape
{
type
,
old_lens
};
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
// in case of axes to be negative, tune to positive
std
::
vector
<
int64_t
>
tuned_axes
(
axes
.
size
());
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
tuned_axes
.
begin
(),
[
new_size
](
auto
i
)
{
return
i
>=
0
?
i
:
i
+
new_size
;
});
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
size_t
p
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
new_size
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
if
(
std
::
find
(
tuned_
axes
.
begin
(),
tuned_
axes
.
end
(),
i
)
!=
tuned_
axes
.
end
())
{
new_lens
[
i
]
=
1
;
}
...
...
src/include/migraphx/operation.hpp
View file @
3a848f0d
...
...
@@ -257,11 +257,17 @@ struct operation
template
<
typename
PrivateDetailTypeErasedT
>
operation
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
unique
())
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
else
{
operation
rhs
(
value
);
swap
(
private_detail_te_handle_mem_var
,
rhs
.
private_detail_te_handle_mem_var
);
}
return
*
this
;
}
...
...
@@ -269,7 +275,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
@@ -280,7 +286,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
src/include/migraphx/operators.hpp
View file @
3a848f0d
...
...
@@ -4,12 +4,15 @@
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/acosh.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/asinh.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
...
...
@@ -23,6 +26,7 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/cosh.hpp>
#include <migraphx/op/cos.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/div.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
...
...
@@ -48,13 +52,15 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_
sum
.hpp>
#include <migraphx/op/reduce_
max
.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_min.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_prod.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
...
...
src/include/migraphx/pass.hpp
View file @
3a848f0d
...
...
@@ -57,11 +57,17 @@ struct pass
template
<
typename
PrivateDetailTypeErasedT
>
pass
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
unique
())
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
else
{
pass
rhs
(
value
);
swap
(
private_detail_te_handle_mem_var
,
rhs
.
private_detail_te_handle_mem_var
);
}
return
*
this
;
}
...
...
@@ -69,7 +75,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
@@ -80,7 +86,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
Prev
1
2
3
4
5
6
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment