Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3a848f0d
"tests/gpt2/test_modeling_flax_gpt2.py" did not exist on "eb3e072a3b24806c72e35e9246e1cf972de1c77f"
Commit
3a848f0d
authored
Mar 19, 2020
by
Paul
Browse files
Merge branch 'develop' into doc2
parents
64e8e30a
d1e945da
Changes
208
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
319 additions
and
55 deletions
+319
-55
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+1
-1
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+13
-1
src/include/migraphx/op/acosh.hpp
src/include/migraphx/op/acosh.hpp
+22
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+14
-7
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+13
-7
src/include/migraphx/op/asinh.hpp
src/include/migraphx/op/asinh.hpp
+22
-0
src/include/migraphx/op/atanh.hpp
src/include/migraphx/op/atanh.hpp
+22
-0
src/include/migraphx/op/deconvolution.hpp
src/include/migraphx/op/deconvolution.hpp
+68
-0
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+12
-9
src/include/migraphx/op/logsoftmax.hpp
src/include/migraphx/op/logsoftmax.hpp
+3
-2
src/include/migraphx/op/prelu.hpp
src/include/migraphx/op/prelu.hpp
+22
-0
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+9
-0
src/include/migraphx/op/reduce_prod.hpp
src/include/migraphx/op/reduce_prod.hpp
+27
-0
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+3
-2
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+12
-4
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+14
-5
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+8
-1
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+13
-7
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+8
-2
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+13
-7
No files found.
src/include/migraphx/instruction.hpp
View file @
3a848f0d
...
...
@@ -93,7 +93,7 @@ struct instruction
void
replace
(
const
shape
&
r
);
operation
op
;
shape
result
;
shape
result
{}
;
std
::
vector
<
instruction_ref
>
output
;
std
::
vector
<
instruction_ref
>
arguments
;
literal
lit
;
...
...
src/include/migraphx/onnx.hpp
View file @
3a848f0d
...
...
@@ -7,8 +7,20 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// struct to pass in onnx options to parser
struct
onnx_options
{
unsigned
int
batch_size
=
1
;
};
/// Create a program from an onnx file
program
parse_onnx
(
const
std
::
string
&
name
);
program
parse_onnx
(
const
std
::
string
&
name
,
onnx_options
=
onnx_options
{});
/// Create a program from an onnx buffer
program
parse_onnx_buffer
(
const
std
::
string
&
buffer
,
onnx_options
options
);
/// Create a program from an onnx buffer
program
parse_onnx_buffer
(
const
void
*
data
,
std
::
size_t
size
,
onnx_options
options
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/op/acosh.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
acosh
:
unary
<
acosh
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
acosh
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/argmax.hpp
View file @
3a848f0d
...
...
@@ -27,25 +27,30 @@ struct argmax
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
if
(
axis
>=
n_dim
||
axis
<
-
n_dim
)
{
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
}
lens
[
axis
]
=
1
;
int64_t
tuned_axis
=
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
lens
[
tuned_axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
int64_t
calc_argmax
(
T
&
input
,
int64_t
tuned_axis
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
max_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
max_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
indices
[
tuned_
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
max_val
<
cur_val
)
{
max_val
=
cur_val
;
...
...
@@ -59,13 +64,15 @@ struct argmax
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
auto
tuned_axis
=
axis
<
0
?
axis
+
n_dim
:
axis
;
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
tuned_axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
tuned_axis
,
data_idx
,
batch_item_num
);
});
});
});
...
...
src/include/migraphx/op/argmin.hpp
View file @
3a848f0d
...
...
@@ -27,25 +27,29 @@ struct argmin
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
if
(
axis
>=
n_dim
||
axis
<
-
n_dim
)
{
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
}
lens
[
axis
]
=
1
;
int64_t
tuned_axis
=
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
lens
[
tuned_axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmin
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
int64_t
calc_argmin
(
T
&
input
,
int64_t
tuned_axis
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
min_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
min_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
indices
[
tuned_
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
min_val
>
cur_val
)
{
min_val
=
cur_val
;
...
...
@@ -59,13 +63,15 @@ struct argmin
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
std
::
size_t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
auto
tuned_axis
=
axis
<
0
?
axis
+
n_dim
:
axis
;
std
::
size_t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
tuned_axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmin
(
input
,
data_idx
,
batch_item_num
);
output
[
i
]
=
this
->
calc_argmin
(
input
,
tuned_axis
,
data_idx
,
batch_item_num
);
});
});
});
...
...
src/include/migraphx/op/asinh.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
asinh
:
unary
<
asinh
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
asinh
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/atanh.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
atanh
:
unary
<
atanh
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
atanh
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/deconvolution.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
deconvolution
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{{
1
,
1
}};
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
group
,
"group"
));
}
std
::
string
name
()
const
{
return
"deconvolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
auto
t
=
input
.
type
();
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
stride
[
0
]
*
(
input
.
lens
()[
2
]
-
1
)
+
((
weights
.
lens
()[
2
]
-
1
)
*
dilation
[
0
]
+
1
)
-
2
*
padding
[
0
])),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
stride
[
1
]
*
(
input
.
lens
()[
3
]
-
1
)
+
((
weights
.
lens
()[
3
]
-
1
)
*
dilation
[
1
]
+
1
)
-
2
*
padding
[
1
])),
}};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/flatten.hpp
View file @
3a848f0d
...
...
@@ -18,7 +18,7 @@ namespace op {
struct
flatten
{
u
int64_t
axis
=
0
;
int64_t
axis
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -30,16 +30,19 @@ struct flatten
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
if
(
axis
>
lens
.
size
()
)
auto
&&
lens
=
inputs
.
front
().
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>
n_dim
or
axis
<
-
n_dim
)
{
MIGRAPHX_THROW
(
"axis for flatten
must be less than tensor
ran
k
"
);
MIGRAPHX_THROW
(
"
FLATTEN:
axis for flatten
is out of
ran
ge
"
);
}
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
tuned_axis
=
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
tuned_axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
tuned_axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/logsoftmax.hpp
View file @
3a848f0d
...
...
@@ -11,7 +11,7 @@ namespace op {
struct
logsoftmax
{
int
axis
=
1
;
int
64_t
axis
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -23,7 +23,8 @@ struct logsoftmax
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
standard
();
if
(
axis
<
0
||
axis
>=
inputs
[
0
].
lens
().
size
())
int64_t
n_dim
=
static_cast
<
int64_t
>
(
inputs
[
0
].
lens
().
size
());
if
(
axis
<
-
n_dim
||
axis
>=
n_dim
)
{
MIGRAPHX_THROW
(
"LogSoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
" is out of range"
);
...
...
src/include/migraphx/op/prelu.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_PRELU_HPP
#define MIGRAPHX_GUARD_OPERATORS_PRELU_HPP
#include <migraphx/op/binary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
prelu
:
binary
<
prelu
>
{
auto
apply
()
const
{
return
[](
auto
x
,
auto
slope
)
{
return
((
x
<
0
)
?
(
x
*
slope
)
:
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reduce_op.hpp
View file @
3a848f0d
...
...
@@ -40,6 +40,15 @@ struct zero
}
};
struct
one
{
template
<
class
T
>
operator
T
()
const
{
return
T
{
1
};
}
};
template
<
class
Derived
>
struct
reduce_op
:
op_name
<
Derived
>
{
...
...
src/include/migraphx/op/reduce_prod.hpp
0 → 100644
View file @
3a848f0d
#ifndef MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP
#define MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP
#include <migraphx/op/reduce_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
reduce_prod
:
reduce_op
<
reduce_prod
>
{
reduce_prod
()
{}
reduce_prod
(
std
::
vector
<
int64_t
>
ax
)
:
reduce_op
(
std
::
move
(
ax
))
{}
auto
op
()
const
{
return
[
=
](
auto
x
,
auto
y
)
{
return
x
*
y
;
};
}
auto
init
()
const
{
return
one
();
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/softmax.hpp
View file @
3a848f0d
...
...
@@ -11,7 +11,7 @@ namespace op {
struct
softmax
{
int
axis
=
1
;
int
64_t
axis
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -23,7 +23,8 @@ struct softmax
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
standard
();
if
(
axis
<
0
||
axis
>=
inputs
[
0
].
lens
().
size
())
int64_t
n_dim
=
inputs
[
0
].
lens
().
size
();
if
(
axis
<
-
n_dim
||
axis
>=
n_dim
)
{
MIGRAPHX_THROW
(
"SoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
" is out of range"
);
...
...
src/include/migraphx/op/squeeze.hpp
View file @
3a848f0d
...
...
@@ -33,13 +33,21 @@ struct squeeze
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
lens
()[
axis
]
!=
1
;
}))
// change to support negative axis value
std
::
vector
<
int64_t
>
tuned_axes
(
axes
.
size
());
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
tuned_axes
.
begin
(),
[
&
](
auto
i
)
{
return
i
>=
0
?
i
:
i
+
old_lens
.
size
();
});
if
(
std
::
any_of
(
tuned_axes
.
begin
(),
tuned_axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
if
(
axes
.
empty
())
if
(
tuned_
axes
.
empty
())
{
std
::
copy_if
(
old_lens
.
begin
(),
old_lens
.
end
(),
...
...
@@ -50,7 +58,7 @@ struct squeeze
{
for
(
std
::
size_t
i
=
0
;
i
<
old_lens
.
size
();
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
if
(
std
::
find
(
tuned_
axes
.
begin
(),
tuned_
axes
.
end
(),
i
)
==
tuned_
axes
.
end
())
{
new_lens
.
push_back
(
old_lens
[
i
]);
}
...
...
src/include/migraphx/op/transpose.hpp
View file @
3a848f0d
...
...
@@ -34,13 +34,22 @@ struct transpose
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input_lens
.
size
())
auto
tuned_dims
=
dims
;
// if not perm provided, reverse the dims
if
(
tuned_dims
.
empty
())
{
tuned_dims
.
resize
(
input_lens
.
size
());
std
::
iota
(
tuned_dims
.
begin
(),
tuned_dims
.
end
(),
0
);
std
::
reverse
(
tuned_dims
.
begin
(),
tuned_dims
.
end
());
}
if
(
tuned_dims
.
size
()
!=
input_lens
.
size
())
{
MIGRAPHX_THROW
(
"Permutation has wrong number of axes"
);
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
vector
<
int64_t
>
axes
(
tuned_
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
tuned_
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"Invalid permutation"
);
}
...
...
@@ -48,8 +57,8 @@ struct transpose
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
for
(
std
::
size_t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
output_lens
[
i
]
=
input_lens
[
tuned_
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
tuned_
dims
[
i
]];
}
return
{
t
,
output_lens
,
output_strides
};
}
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
3a848f0d
...
...
@@ -38,11 +38,18 @@ struct unsqueeze
return
shape
{
type
,
old_lens
};
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
// in case of axes to be negative, tune to positive
std
::
vector
<
int64_t
>
tuned_axes
(
axes
.
size
());
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
tuned_axes
.
begin
(),
[
new_size
](
auto
i
)
{
return
i
>=
0
?
i
:
i
+
new_size
;
});
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
size_t
p
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
new_size
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
if
(
std
::
find
(
tuned_
axes
.
begin
(),
tuned_
axes
.
end
(),
i
)
!=
tuned_
axes
.
end
())
{
new_lens
[
i
]
=
1
;
}
...
...
src/include/migraphx/operation.hpp
View file @
3a848f0d
...
...
@@ -257,11 +257,17 @@ struct operation
template
<
typename
PrivateDetailTypeErasedT
>
operation
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
unique
())
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
else
{
operation
rhs
(
value
);
swap
(
private_detail_te_handle_mem_var
,
rhs
.
private_detail_te_handle_mem_var
);
}
return
*
this
;
}
...
...
@@ -269,7 +275,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
@@ -280,7 +286,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
src/include/migraphx/operators.hpp
View file @
3a848f0d
...
...
@@ -4,12 +4,15 @@
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/acosh.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/asinh.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
...
...
@@ -23,6 +26,7 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/cosh.hpp>
#include <migraphx/op/cos.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/div.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
...
...
@@ -48,13 +52,15 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_
sum
.hpp>
#include <migraphx/op/reduce_
max
.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_min.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_prod.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
...
...
src/include/migraphx/pass.hpp
View file @
3a848f0d
...
...
@@ -57,11 +57,17 @@ struct pass
template
<
typename
PrivateDetailTypeErasedT
>
pass
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
unique
())
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
else
{
pass
rhs
(
value
);
swap
(
private_detail_te_handle_mem_var
,
rhs
.
private_detail_te_handle_mem_var
);
}
return
*
this
;
}
...
...
@@ -69,7 +75,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
@@ -80,7 +86,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
return
this
->
type
_id
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
...
...
Prev
1
2
3
4
5
6
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment