Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bf6f82d8
Commit
bf6f82d8
authored
May 16, 2019
by
Paul
Browse files
Merge from develop
parents
6a0797e2
b93f5320
Changes
92
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
240 additions
and
109 deletions
+240
-109
src/include/migraphx/op/leaky_relu.hpp
src/include/migraphx/op/leaky_relu.hpp
+7
-6
src/include/migraphx/op/load.hpp
src/include/migraphx/op/load.hpp
+1
-1
src/include/migraphx/op/logsoftmax.hpp
src/include/migraphx/op/logsoftmax.hpp
+1
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+2
-2
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+28
-33
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+1
-1
src/include/migraphx/op/scalar.hpp
src/include/migraphx/op/scalar.hpp
+1
-1
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+1
-1
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+1
-1
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+9
-1
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+9
-5
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+13
-11
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+12
-1
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+61
-6
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+3
-2
src/instruction.cpp
src/instruction.cpp
+30
-8
src/onnx/cifar10.cpp
src/onnx/cifar10.cpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+57
-26
No files found.
src/include/migraphx/op/leaky_relu.hpp
View file @
bf6f82d8
...
@@ -18,19 +18,20 @@ namespace op {
...
@@ -18,19 +18,20 @@ namespace op {
struct
leaky_relu
struct
leaky_relu
{
{
std
::
string
name
()
const
{
return
"leaky_relu"
;
}
float
alpha
;
float
alpha
;
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
}
}
std
::
string
name
()
const
{
return
"leaky_relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/load.hpp
View file @
bf6f82d8
...
@@ -39,7 +39,7 @@ struct load
...
@@ -39,7 +39,7 @@ struct load
MIGRAPHX_THROW
(
"Load access is out of bounds"
);
MIGRAPHX_THROW
(
"Load access is out of bounds"
);
return
{
s
,
args
[
0
].
data
()
+
offset
};
return
{
s
,
args
[
0
].
data
()
+
offset
};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
load
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
load
&
op
)
{
{
...
...
src/include/migraphx/op/logsoftmax.hpp
View file @
bf6f82d8
...
@@ -29,7 +29,7 @@ struct logsoftmax
...
@@ -29,7 +29,7 @@ struct logsoftmax
std
::
string
name
()
const
{
return
"logsoftmax"
;
}
std
::
string
name
()
const
{
return
"logsoftmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
)
.
standard
()
;
if
(
axis
<
0
||
axis
>
inputs
[
0
].
lens
().
size
())
if
(
axis
<
0
||
axis
>
inputs
[
0
].
lens
().
size
())
{
{
MIGRAPHX_THROW
(
"LogSoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
MIGRAPHX_THROW
(
"LogSoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
bf6f82d8
...
@@ -42,7 +42,7 @@ struct multibroadcast
...
@@ -42,7 +42,7 @@ struct multibroadcast
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
for
(
in
t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_
t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
{
{
...
@@ -55,7 +55,7 @@ struct multibroadcast
...
@@ -55,7 +55,7 @@ struct multibroadcast
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/pooling.hpp
View file @
bf6f82d8
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -49,20 +50,19 @@ struct pooling
...
@@ -49,20 +50,19 @@ struct pooling
if
(
padding_mode
==
default_
)
if
(
padding_mode
==
default_
)
{
{
return
{
return
{
t
,
t
,
{
{
input
.
lens
()[
0
],
input
.
lens
()[
0
],
input
.
lens
()[
1
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
floor_divide
<
std
::
ptrdiff_t
>
(
static_cast
<
float
>
(
stride
[
0
])
))
+
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
],
stride
[
0
])
+
1
)),
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
floor_divide
<
std
::
ptrdiff_t
>
(
static_cast
<
float
>
(
stride
[
1
])
))
+
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
],
stride
[
1
])
+
1
)),
1
)),
}};
}};
}
}
...
@@ -71,27 +71,22 @@ struct pooling
...
@@ -71,27 +71,22 @@ struct pooling
return
{
t
,
return
{
t
,
{
input
.
lens
()[
0
],
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
input
.
lens
()[
1
],
static_cast
<
std
::
size_t
>
(
ceil_divide
<
std
::
size_t
>
(
input
.
lens
()[
2
],
stride
[
0
]),
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
2
])
/
stride
[
0
])),
ceil_divide
<
std
::
size_t
>
(
input
.
lens
()[
3
],
stride
[
1
])}};
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
3
])
/
stride
[
1
]))}};
}
}
else
if
(
padding_mode
==
valid
)
else
if
(
padding_mode
==
valid
)
{
{
return
{
t
,
return
{
t
,
{
{
input
.
lens
()[
0
],
input
.
lens
()[
0
],
input
.
lens
()[
1
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
2
]
-
lengths
[
0
])
/
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
2
]
-
lengths
[
0
],
stride
[
0
])
+
1
)),
static_cast
<
float
>
(
stride
[
0
])))
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
3
]
-
lengths
[
1
])
/
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
3
]
-
lengths
[
1
],
stride
[
1
])
+
1
)),
static_cast
<
float
>
(
stride
[
1
])))
+
1
)),
}};
}};
}
}
else
else
...
...
src/include/migraphx/op/reshape.hpp
View file @
bf6f82d8
...
@@ -66,7 +66,7 @@ struct reshape
...
@@ -66,7 +66,7 @@ struct reshape
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/scalar.hpp
View file @
bf6f82d8
...
@@ -40,7 +40,7 @@ struct scalar
...
@@ -40,7 +40,7 @@ struct scalar
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/slice.hpp
View file @
bf6f82d8
...
@@ -86,7 +86,7 @@ struct slice
...
@@ -86,7 +86,7 @@ struct slice
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
return
{
std
::
move
(
output_shape
),
[
=
]
{
return
input
.
data
()
+
offset
;
}};
return
{
std
::
move
(
output_shape
),
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/squeeze.hpp
View file @
bf6f82d8
...
@@ -70,7 +70,7 @@ struct squeeze
...
@@ -70,7 +70,7 @@ struct squeeze
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/transpose.hpp
View file @
bf6f82d8
...
@@ -57,7 +57,7 @@ struct transpose
...
@@ -57,7 +57,7 @@ struct transpose
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/unary.hpp
View file @
bf6f82d8
...
@@ -13,7 +13,15 @@ struct unary : op_name<Derived>
...
@@ -13,7 +13,15 @@ struct unary : op_name<Derived>
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
auto
s
=
inputs
.
at
(
0
);
if
(
s
.
packed
())
{
return
s
;
}
else
{
return
{
s
.
type
(),
s
.
lens
()};
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
bf6f82d8
...
@@ -29,10 +29,14 @@ struct unsqueeze
...
@@ -29,10 +29,14 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
_or_scalar
();
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_lens
=
input_shape
.
lens
();
if
(
input_shape
.
scalar
())
return
shape
{
type
,
old_lens
};
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
...
@@ -53,7 +57,7 @@ struct unsqueeze
...
@@ -53,7 +57,7 @@ struct unsqueeze
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/operation.hpp
View file @
bf6f82d8
...
@@ -49,7 +49,7 @@ struct operation
...
@@ -49,7 +49,7 @@ struct operation
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
;
/// An optional method to return which argument the output will alias. If
/// An optional method to return which argument the output will alias. If
/// there is no aliased output then -1 can be returned.
/// there is no aliased output then -1 can be returned.
in
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
;
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
;
/// An optional stream operator to print the operation. When this is not
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
/// implemented, it will just print the operation's name.
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
...
@@ -69,7 +69,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -69,7 +69,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
{
os
<<
x
.
name
();
os
<<
x
.
name
();
char
delim
=
'['
;
char
delim
=
'['
;
reflect_each
(
x
,
[
&
](
auto
&
y
,
auto
name
)
{
reflect_each
(
x
,
[
&
](
auto
&
&
y
,
auto
name
)
{
os
<<
delim
;
os
<<
delim
;
os
<<
name
<<
"="
;
os
<<
name
<<
"="
;
stream_write_value
(
os
,
y
);
stream_write_value
(
os
,
y
);
...
@@ -87,6 +87,8 @@ namespace operation_equal {
...
@@ -87,6 +87,8 @@ namespace operation_equal {
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
auto
operator
==
(
const
T
&
x
,
const
U
&
y
)
->
decltype
(
x
.
name
()
==
y
.
name
())
auto
operator
==
(
const
T
&
x
,
const
U
&
y
)
->
decltype
(
x
.
name
()
==
y
.
name
())
{
{
static_assert
(
is_reflectable
<
T
>
{}
or
sizeof
(
T
)
<=
1
,
"Missing equality operator or reflect method."
);
if
(
x
.
name
()
!=
y
.
name
())
if
(
x
.
name
()
!=
y
.
name
())
return
false
;
return
false
;
const
auto
&
yy
=
any_cast
<
T
>
(
y
);
const
auto
&
yy
=
any_cast
<
T
>
(
y
);
...
@@ -175,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
...
@@ -175,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
}
}
template
<
class
T
>
template
<
class
T
>
in
t
output_alias_op
(
rank
<
0
>
,
const
T
&
,
const
std
::
vector
<
shape
>&
)
std
::
ptrdiff_
t
output_alias_op
(
rank
<
0
>
,
const
T
&
,
const
std
::
vector
<
shape
>&
)
{
{
return
-
1
;
return
-
1
;
}
}
...
@@ -188,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
...
@@ -188,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
}
}
template
<
class
T
>
template
<
class
T
>
in
t
output_alias_op
(
const
T
&
x
,
const
std
::
vector
<
shape
>&
shapes
)
std
::
ptrdiff_
t
output_alias_op
(
const
T
&
x
,
const
std
::
vector
<
shape
>&
shapes
)
{
{
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
}
}
...
@@ -239,7 +241,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
...
@@ -239,7 +241,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
* std::string name() const;
* std::string name() const;
* bool is_context_free() const;
* bool is_context_free() const;
* bool has_finalize() const;
* bool has_finalize() const;
*
in
t output_alias(const std::vector<shape>& input) const;
*
std::ptrdiff_
t output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
...
@@ -325,7 +327,7 @@ struct operation
...
@@ -325,7 +327,7 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
has_finalize
();
return
(
*
this
).
private_detail_te_get_handle
().
has_finalize
();
}
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
output_alias
(
input
);
return
(
*
this
).
private_detail_te_get_handle
().
output_alias
(
input
);
...
@@ -383,7 +385,7 @@ struct operation
...
@@ -383,7 +385,7 @@ struct operation
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
bool
is_context_free
()
const
=
0
;
virtual
bool
is_context_free
()
const
=
0
;
virtual
bool
has_finalize
()
const
=
0
;
virtual
bool
has_finalize
()
const
=
0
;
virtual
in
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
void
virtual
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
=
0
;
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
...
@@ -432,7 +434,7 @@ struct operation
...
@@ -432,7 +434,7 @@ struct operation
bool
has_finalize
()
const
override
{
return
has_finalize_op
(
private_detail_te_value
);
}
bool
has_finalize
()
const
override
{
return
has_finalize_op
(
private_detail_te_value
);
}
in
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
std
::
ptrdiff_
t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
{
return
output_alias_op
(
private_detail_te_value
,
input
);
return
output_alias_op
(
private_detail_te_value
,
input
);
...
...
src/include/migraphx/operators.hpp
View file @
bf6f82d8
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/contiguous.hpp>
...
...
src/include/migraphx/program.hpp
View file @
bf6f82d8
...
@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins);
...
@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins);
struct
program
struct
program
{
{
program
();
program
();
// move constructor
program
(
program
&&
)
noexcept
;
program
(
program
&&
)
noexcept
;
program
&
operator
=
(
program
&&
)
noexcept
;
// copy constructor
program
(
const
program
&
);
// copy assignment operator
program
&
operator
=
(
program
);
~
program
()
noexcept
;
~
program
()
noexcept
;
using
parameter_map
=
std
::
unordered_map
<
std
::
string
,
argument
>
;
using
parameter_map
=
std
::
unordered_map
<
std
::
string
,
argument
>
;
...
@@ -118,6 +126,9 @@ struct program
...
@@ -118,6 +126,9 @@ struct program
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
private:
void
assign
(
const
program
&
p
);
private:
private:
std
::
unique_ptr
<
program_impl
>
impl
;
std
::
unique_ptr
<
program_impl
>
impl
;
};
};
...
...
src/include/migraphx/reflect.hpp
View file @
bf6f82d8
...
@@ -11,6 +11,15 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -11,6 +11,15 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
detail
{
namespace
detail
{
struct
reflect_placeholder
{
template
<
class
...
Ts
>
int
operator
()(
Ts
&&
...)
const
{
return
0
;
}
};
template
<
class
T
,
class
Selector
>
template
<
class
T
,
class
Selector
>
auto
reflect_impl
(
rank
<
1
>
,
T
&
x
,
Selector
f
)
->
decltype
(
T
::
reflect
(
x
,
f
))
auto
reflect_impl
(
rank
<
1
>
,
T
&
x
,
Selector
f
)
->
decltype
(
T
::
reflect
(
x
,
f
))
{
{
...
@@ -23,8 +32,53 @@ auto reflect_impl(rank<0>, T&, Selector)
...
@@ -23,8 +32,53 @@ auto reflect_impl(rank<0>, T&, Selector)
return
pack
();
return
pack
();
}
}
template
<
class
T
>
auto
reflectable_impl
(
rank
<
1
>
,
T
&&
x
)
->
decltype
(
T
::
reflect
(
x
,
reflect_placeholder
{}),
std
::
true_type
{});
template
<
class
T
>
auto
reflectable_impl
(
rank
<
0
>
,
T
&&
)
->
decltype
(
std
::
false_type
{});
template
<
class
T
>
struct
remove_rvalue_reference
{
using
type
=
T
;
};
template
<
class
T
>
struct
remove_rvalue_reference
<
T
&&>
{
using
type
=
T
;
};
template
<
class
T
>
struct
wrapper
{
using
type
=
typename
remove_rvalue_reference
<
T
>::
type
;
type
data
;
type
get
()
const
{
return
data
;
}
};
template
<
class
T
>
wrapper
<
T
>
wrap
(
std
::
remove_reference_t
<
T
>&
x
)
{
return
wrapper
<
T
>
{
std
::
forward
<
T
>
(
x
)};
}
template
<
class
...
Ts
>
using
auto_tuple_t
=
std
::
tuple
<
typename
remove_rvalue_reference
<
Ts
>::
type
...
>
;
template
<
class
...
Ts
>
auto_tuple_t
<
Ts
...
>
auto_tuple
(
Ts
&&
...
xs
)
{
return
auto_tuple_t
<
Ts
...
>
{
std
::
forward
<
Ts
>
(
xs
)...};
}
}
// namespace detail
}
// namespace detail
template
<
class
T
>
using
is_reflectable
=
decltype
(
detail
::
reflectable_impl
(
rank
<
1
>
{},
std
::
declval
<
T
>
()));
template
<
class
T
,
class
Selector
>
template
<
class
T
,
class
Selector
>
auto
reflect
(
T
&
x
,
Selector
f
)
auto
reflect
(
T
&
x
,
Selector
f
)
{
{
...
@@ -34,15 +88,16 @@ auto reflect(T& x, Selector f)
...
@@ -34,15 +88,16 @@ auto reflect(T& x, Selector f)
template
<
class
T
>
template
<
class
T
>
auto
reflect_tie
(
T
&
x
)
auto
reflect_tie
(
T
&
x
)
{
{
return
reflect
(
x
,
[](
auto
&&
y
,
auto
&&
...)
{
return
std
::
ref
(
y
);
})(
return
reflect
(
x
,
[](
auto
&&
y
,
auto
&&
...)
{
return
detail
::
wrap
<
decltype
(
y
)
>
(
y
);
})(
[](
auto
&&
...
xs
)
{
return
std
::
ti
e
(
xs
.
get
()...);
});
[](
auto
&&
...
xs
)
{
return
detail
::
auto_tupl
e
(
xs
.
get
()...);
});
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
void
reflect_each
(
T
&
x
,
F
f
)
void
reflect_each
(
T
&
x
,
F
f
)
{
{
return
reflect
(
x
,
[](
auto
&&
y
,
auto
...
ys
)
{
return
pack
(
std
::
ref
(
y
),
ys
...);
})(
return
reflect
(
x
,
[](
auto
&&
y
,
auto
...
ys
)
{
[
&
](
auto
&&
...
xs
)
{
return
pack
(
detail
::
wrap
<
decltype
(
y
)
>
(
y
),
ys
...);
})([
&
](
auto
&&
...
xs
)
{
each_args
([
&
](
auto
p
)
{
p
([
&
](
auto
&&
y
,
auto
...
ys
)
{
f
(
y
.
get
(),
ys
...);
});
},
xs
...);
each_args
([
&
](
auto
p
)
{
p
([
&
](
auto
&&
y
,
auto
...
ys
)
{
f
(
y
.
get
(),
ys
...);
});
},
xs
...);
});
});
}
}
...
...
src/include/migraphx/stringutils.hpp
View file @
bf6f82d8
...
@@ -38,8 +38,9 @@ inline std::string join_strings(Strings strings, const std::string& delim)
...
@@ -38,8 +38,9 @@ inline std::string join_strings(Strings strings, const std::string& delim)
return
""
;
return
""
;
auto
nit
=
std
::
next
(
it
);
auto
nit
=
std
::
next
(
it
);
return
std
::
accumulate
(
return
std
::
accumulate
(
nit
,
strings
.
end
(),
*
it
,
[
&
](
std
::
string
x
,
std
::
string
y
)
{
nit
,
strings
.
end
(),
*
it
,
[
&
](
std
::
string
x
,
std
::
string
y
)
{
return
x
+
delim
+
y
;
});
return
std
::
move
(
x
)
+
delim
+
std
::
move
(
y
);
});
}
}
template
<
class
F
>
template
<
class
F
>
...
...
src/instruction.cpp
View file @
bf6f82d8
...
@@ -28,6 +28,12 @@ void instruction::replace(const shape& r)
...
@@ -28,6 +28,12 @@ void instruction::replace(const shape& r)
}
}
}
}
void
instruction
::
replace
(
operation
o
)
{
op
=
std
::
move
(
o
);
recompute_shape
();
}
void
instruction
::
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
instruction
::
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
instruction
::
clear_arguments
()
void
instruction
::
clear_arguments
()
...
@@ -162,7 +168,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
...
@@ -162,7 +168,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old
->
remove_output
(
*
this
);
old
->
remove_output
(
*
this
);
}
}
argument
instruction
::
eval
()
const
bool
instruction
::
can_eval
()
const
{
if
(
op
.
name
()
==
"@literal"
)
{
return
true
;
}
else
if
(
is_context_free
(
op
))
{
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
can_eval
();
});
}
else
{
return
false
;
}
}
argument
instruction
::
eval
(
bool
check_eval
)
const
{
{
if
(
op
.
name
()
==
"@literal"
)
if
(
op
.
name
()
==
"@literal"
)
{
{
...
@@ -170,14 +193,13 @@ argument instruction::eval() const
...
@@ -170,14 +193,13 @@ argument instruction::eval() const
}
}
if
(
is_context_free
(
op
))
if
(
is_context_free
(
op
))
{
{
std
::
vector
<
argument
>
args
;
if
(
check_eval
and
not
this
->
can_eval
())
for
(
auto
&&
arg
:
this
->
inputs
())
{
argument
a
=
arg
->
eval
();
if
(
a
.
empty
())
return
{};
return
{};
args
.
push_back
(
a
);
std
::
vector
<
argument
>
args
;
}
std
::
transform
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
std
::
back_inserter
(
args
),
[](
auto
arg
)
{
return
arg
->
eval
(
false
);
});
return
op
.
compute
(
result
,
args
);
return
op
.
compute
(
result
,
args
);
}
}
return
{};
return
{};
...
...
src/onnx/cifar10.cpp
View file @
bf6f82d8
...
@@ -32,7 +32,7 @@ auto read_cifar10_images(const std::string& full_path)
...
@@ -32,7 +32,7 @@ auto read_cifar10_images(const std::string& full_path)
labels
[
i
]
=
*
pimage
++
;
labels
[
i
]
=
*
pimage
++
;
for
(
size_t
j
=
0
;
j
<
nbytes_per_image
;
j
++
)
for
(
size_t
j
=
0
;
j
<
nbytes_per_image
;
j
++
)
{
{
float
v
=
*
(
pimage
+
j
)
/
255.0
f
;
float
v
=
float
(
*
(
pimage
+
j
)
)
/
255.0
f
;
data
[
i
*
nbytes_per_image
+
j
]
=
v
;
data
[
i
*
nbytes_per_image
+
j
]
=
v
;
}
}
}
}
...
...
src/onnx/onnx.cpp
View file @
bf6f82d8
...
@@ -63,6 +63,7 @@ struct onnx_parser
...
@@ -63,6 +63,7 @@ struct onnx_parser
add_variadic_op
(
"Max"
,
op
::
max
{});
add_variadic_op
(
"Max"
,
op
::
max
{});
add_variadic_op
(
"Min"
,
op
::
min
{});
add_variadic_op
(
"Min"
,
op
::
min
{});
add_mem_op
(
"Clip"
,
&
onnx_parser
::
parse_clip
);
add_mem_op
(
"LRN"
,
&
onnx_parser
::
parse_lrn
);
add_mem_op
(
"LRN"
,
&
onnx_parser
::
parse_lrn
);
add_mem_op
(
"ImageScaler"
,
&
onnx_parser
::
parse_imagescaler
);
add_mem_op
(
"ImageScaler"
,
&
onnx_parser
::
parse_imagescaler
);
add_mem_op
(
"LeakyRelu"
,
&
onnx_parser
::
parse_leaky_relu
);
add_mem_op
(
"LeakyRelu"
,
&
onnx_parser
::
parse_leaky_relu
);
...
@@ -207,7 +208,7 @@ struct onnx_parser
...
@@ -207,7 +208,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
return
prog
.
add_instruction
(
x
,
args
);
});
});
}
}
...
@@ -215,7 +216,7 @@ struct onnx_parser
...
@@ -215,7 +216,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_variadic_op
(
std
::
string
name
,
T
x
)
void
add_variadic_op
(
std
::
string
name
,
T
x
)
{
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
args
.
end
(),
args
.
end
(),
args
.
front
(),
args
.
front
(),
...
@@ -225,6 +226,22 @@ struct onnx_parser
...
@@ -225,6 +226,22 @@ struct onnx_parser
});
});
}
}
instruction_ref
parse_clip
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
op
::
clip
op
;
if
(
contains
(
attributes
,
"max"
))
{
op
.
max_val
=
parse_value
(
attributes
.
at
(
"max"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"min"
))
{
op
.
min_val
=
parse_value
(
attributes
.
at
(
"min"
)).
at
<
float
>
();
}
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
instruction_ref
instruction_ref
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
{
...
@@ -1361,28 +1378,26 @@ struct onnx_parser
...
@@ -1361,28 +1378,26 @@ struct onnx_parser
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
empty
())
{
dims
=
{
1
};
}
if
(
t
.
has_raw_data
())
if
(
t
.
has_raw_data
())
{
{
const
std
::
string
&
s
=
t
.
raw_data
();
const
std
::
string
&
s
=
t
.
raw_data
();
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
}
,
s
.
data
()
}
;
case
onnx
::
TensorProto
::
FLOAT
:
return
create_
literal
(
shape
::
float_type
,
dims
,
s
.
data
()
)
;
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT8
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
FLOAT16
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
return
create_literal
(
shape
::
half_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
...
@@ -1394,21 +1409,21 @@ struct onnx_parser
...
@@ -1394,21 +1409,21 @@ struct onnx_parser
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
}
,
t
.
float_data
()
.
begin
(),
t
.
float_data
().
end
()}
;
return
create_
literal
(
shape
::
float_type
,
dims
,
t
.
float_data
()
)
;
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT64
:
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
}
,
t
.
int64_data
()
.
begin
(),
t
.
int64_data
().
end
()}
;
return
create_
literal
(
shape
::
int64_type
,
dims
,
t
.
int64_data
()
)
;
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
FLOAT16
:
{
{
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
...
@@ -1417,11 +1432,10 @@ struct onnx_parser
...
@@ -1417,11 +1432,10 @@ struct onnx_parser
data_uint16
.
end
(),
data_uint16
.
end
(),
std
::
back_inserter
(
data_half
),
std
::
back_inserter
(
data_half
),
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
return
literal
{{
shape
::
half_type
,
dims
}
,
data_half
.
begin
(),
data_half
.
end
()}
;
return
create_
literal
(
shape
::
half_type
,
dims
,
data_half
)
;
}
}
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
{
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
...
@@ -1430,6 +1444,23 @@ struct onnx_parser
...
@@ -1430,6 +1444,23 @@ struct onnx_parser
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
const
char
*
data
)
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
empty
())
return
literal
{{
shape_type
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
std
::
is_pointer
<
T
>{})
>
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
T
data
)
{
if
(
dims
.
empty
())
return
literal
{{
shape_type
},
data
.
begin
(),
data
.
end
()};
return
literal
{{
shape_type
,
dims
},
data
.
begin
(),
data
.
end
()};
}
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
{
shape
::
type_t
shape_type
{};
shape
::
type_t
shape_type
{};
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment