Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b878f78f
Commit
b878f78f
authored
Aug 12, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/develop' into rewrite-fast-gelu
parents
3b414cc2
55cb7d3a
Changes
197
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
183 additions
and
91 deletions
+183
-91
src/include/migraphx/op/tanh.hpp
src/include/migraphx/op/tanh.hpp
+1
-8
src/include/migraphx/op/topk.hpp
src/include/migraphx/op/topk.hpp
+1
-0
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+2
-5
src/include/migraphx/op/unary_not.hpp
src/include/migraphx/op/unary_not.hpp
+2
-3
src/include/migraphx/op/unknown.hpp
src/include/migraphx/op/unknown.hpp
+0
-1
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+2
-7
src/include/migraphx/op/where.hpp
src/include/migraphx/op/where.hpp
+2
-9
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+5
-3
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-0
src/include/migraphx/pad_calc.hpp
src/include/migraphx/pad_calc.hpp
+23
-25
src/include/migraphx/par_dfor.hpp
src/include/migraphx/par_dfor.hpp
+1
-0
src/include/migraphx/sqlite.hpp
src/include/migraphx/sqlite.hpp
+51
-0
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+6
-6
src/insert_pad.cpp
src/insert_pad.cpp
+6
-0
src/instruction.cpp
src/instruction.cpp
+2
-2
src/normalize_ops.cpp
src/normalize_ops.cpp
+2
-2
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+4
-2
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+19
-2
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+51
-15
src/onnx/parse_constant.cpp
src/onnx/parse_constant.cpp
+1
-1
No files found.
src/include/migraphx/op/tanh.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/topk.hpp
View file @
b878f78f
...
...
@@ -26,6 +26,7 @@
#include <algorithm>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
...
...
src/include/migraphx/op/transpose.hpp
View file @
b878f78f
...
...
@@ -24,14 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/unary_not.hpp
View file @
b878f78f
...
...
@@ -24,10 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/unknown.hpp
View file @
b878f78f
...
...
@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
namespace
migraphx
{
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/where.hpp
View file @
b878f78f
...
...
@@ -24,18 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/operation.hpp
View file @
b878f78f
...
...
@@ -68,8 +68,10 @@ struct operation
*
* @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`.
* @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
* For a fixed shape, the returned argument will have the same shape as `output`.
* For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
* set in the dynamic shape `output`.
* @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
...
...
@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
{
dependent_type
<
operation
,
T
>
y
=
x
;
normalize_attributes
(
y
,
inputs
[
0
].
lens
());
normalize_attributes
(
y
,
inputs
[
0
].
max_
lens
());
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
}
...
...
src/include/migraphx/operators.hpp
View file @
b878f78f
...
...
@@ -57,6 +57,7 @@
#include <migraphx/op/exp.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/fmod.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
...
...
@@ -79,6 +80,7 @@
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
#include <migraphx/op/min.hpp>
#include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
...
...
src/include/migraphx/pad_calc.hpp
View file @
b878f78f
...
...
@@ -24,38 +24,36 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <
utility
>
#include <
migraphx/config.hpp
>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
void
calculate_padding
(
int64_t
idx
,
std
::
vector
<
int64_t
>&
pads
,
int64_t
input_dim
,
int64_t
stride
,
int64_t
dilation
,
int64_t
weight_dim
,
bool
is_same_upper
=
true
)
{
int64_t
output_dim
=
(
input_dim
+
stride
-
1
)
/
stride
;
// round up result
int64_t
new_weight_dim
=
weight_dim
+
(
weight_dim
-
1
)
*
(
dilation
-
1
);
int64_t
pad
=
std
::
max
(
static_cast
<
int64_t
>
(
0
),
(
output_dim
-
1
)
*
stride
+
new_weight_dim
-
input_dim
);
auto
pad_ndims
=
pads
.
size
()
/
2
;
void
calculate_padding
(
int64_t
idx
,
std
::
vector
<
int64_t
>&
pads
,
int64_t
input_dim
,
int64_t
stride
,
int64_t
dilation
,
int64_t
weight_dim
,
bool
is_same_upper
=
true
);
if
(
is_same_upper
)
{
pads
[
idx
]
=
pad
/
2
;
pads
[
idx
+
pad_ndims
]
=
pad
-
pad
/
2
;
}
else
{
pads
[
idx
+
pad_ndims
]
=
pad
/
2
;
pads
[
idx
]
=
pad
-
pad
/
2
;
}
}
/*!
* Calculate the padding for auto_padding. Used for dynamic shapes
* where the padding calculation must be done at evaluation time.
* \param tensor_lens input tensor image shape
* \param k_lens weights kernel shape
* \param strides strides for the kernel
* \param dilations dilations for the kernel
* \param use_upper put odd padding on upper or lower side
* \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
*/
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
std
::
vector
<
std
::
size_t
>
tensor_lens
,
std
::
vector
<
std
::
size_t
>
k_lens
,
std
::
vector
<
std
::
size_t
>
strides
,
std
::
vector
<
std
::
size_t
>
dilations
,
bool
use_upper
=
true
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/par_dfor.hpp
View file @
b878f78f
...
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#include <migraphx/par_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <numeric>
...
...
src/include/migraphx/sqlite.hpp
0 → 100644
View file @
b878f78f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <memory>
#include <unordered_map>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
sqlite_impl
;
struct
sqlite
{
sqlite
()
=
default
;
static
sqlite
read
(
const
fs
::
path
&
p
);
static
sqlite
write
(
const
fs
::
path
&
p
);
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
execute
(
const
std
::
string
&
s
);
private:
std
::
shared_ptr
<
sqlite_impl
>
impl
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
src/include/migraphx/stringutils.hpp
View file @
b878f78f
...
...
@@ -174,27 +174,27 @@ inline std::string interpolate_string(const std::string& input,
}
template
<
class
Iterator
>
inline
std
::
string
to_string_range
(
Iterator
start
,
Iterator
last
)
inline
std
::
string
to_string_range
(
Iterator
start
,
Iterator
last
,
const
char
*
delim
=
", "
)
{
std
::
stringstream
ss
;
if
(
start
!=
last
)
{
ss
<<
*
start
;
std
::
for_each
(
std
::
next
(
start
),
last
,
[
&
](
auto
&&
x
)
{
ss
<<
", "
<<
x
;
});
std
::
for_each
(
std
::
next
(
start
),
last
,
[
&
](
auto
&&
x
)
{
ss
<<
delim
<<
x
;
});
}
return
ss
.
str
();
}
template
<
class
Range
>
inline
std
::
string
to_string_range
(
const
Range
&
r
)
inline
std
::
string
to_string_range
(
const
Range
&
r
,
const
char
*
delim
=
", "
)
{
return
to_string_range
(
r
.
begin
(),
r
.
end
());
return
to_string_range
(
r
.
begin
(),
r
.
end
()
,
delim
);
}
template
<
class
T
>
inline
std
::
string
to_string_range
(
const
std
::
initializer_list
<
T
>&
r
)
inline
std
::
string
to_string_range
(
const
std
::
initializer_list
<
T
>&
r
,
const
char
*
delim
=
", "
)
{
return
to_string_range
(
r
.
begin
(),
r
.
end
());
return
to_string_range
(
r
.
begin
(),
r
.
end
()
,
delim
);
}
template
<
class
T
>
...
...
src/insert_pad.cpp
View file @
b878f78f
...
...
@@ -40,6 +40,12 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
auto
val
=
op
.
to_value
();
auto
op_padding
=
val
.
at
(
"padding"
).
to_vector
<
size_t
>
();
// skip if shape is dynamic
if
(
input
->
get_shape
().
dynamic
())
{
return
;
}
auto
kdims
=
input
->
get_shape
().
lens
().
size
()
-
2
;
if
(
std
::
equal
(
op_padding
.
begin
(),
op_padding
.
begin
()
+
kdims
,
...
...
src/instruction.cpp
View file @
b878f78f
...
...
@@ -445,8 +445,8 @@ operation instruction::normalized_operator() const
operation
o
=
this
->
get_operator
();
if
(
this
->
need_normalization
())
{
auto
len
s
=
this
->
inputs
().
front
()
->
get_shape
()
.
lens
()
;
if
(
!
normalize_attributes
(
o
,
lens
))
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
if
(
!
normalize_attributes
(
o
,
s
.
max_
lens
()
))
return
this
->
get_operator
();
}
return
o
;
...
...
src/normalize_ops.cpp
View file @
b878f78f
...
...
@@ -43,9 +43,9 @@ void normalize_ops::apply(module& m) const
if
(
inputs
.
empty
())
continue
;
auto
lens
=
inputs
[
0
]
->
get_shape
()
.
lens
()
;
auto
s
=
inputs
[
0
]
->
get_shape
();
migraphx
::
operation
tuned_op
=
ins
->
get_operator
();
if
(
normalize_attributes
(
tuned_op
,
lens
))
if
(
normalize_attributes
(
tuned_op
,
s
.
max_
lens
()
))
{
m
.
replace_instruction
(
ins
,
tuned_op
,
inputs
);
ins
->
set_normalized
();
...
...
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
b878f78f
...
...
@@ -93,9 +93,10 @@ struct onnx_parser
onnx_parser
&
,
const
node_info
&
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
s
td
::
size_t
default_dim_value
=
1
;
program
prog
=
program
();
s
hape
::
dynamic_dimension
default_
dyn_
dim_value
=
{
1
,
1
,
0
}
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
bool
skip_unknown_operators
=
false
;
int64_t
max_loop_iterations
=
10
;
int64_t
opset_version
=
13
;
...
...
@@ -118,6 +119,7 @@ struct onnx_parser
};
shape
::
type_t
get_type
(
int
dtype
);
bool
is_type_float
(
shape
::
type_t
dtype
);
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/onnx/onnx.cpp
View file @
b878f78f
...
...
@@ -41,8 +41,25 @@ template <class... Ts>
program
parse_onnx_from
(
const
onnx_options
&
options
,
Ts
&&
...
xs
)
{
onnx
::
onnx_parser
parser
;
parser
.
map_input_dims
=
options
.
map_input_dims
;
parser
.
default_dim_value
=
options
.
default_dim_value
;
parser
.
map_input_dims
=
options
.
map_input_dims
;
parser
.
map_dyn_input_dims
=
options
.
map_dyn_input_dims
;
auto
dim_val
=
options
.
default_dim_value
;
if
(
dim_val
!=
0
)
{
if
(
options
.
default_dyn_dim_value
!=
shape
::
dynamic_dimension
{
1
,
1
,
0
})
{
MIGRAPHX_THROW
(
"PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value"
);
}
else
{
parser
.
default_dyn_dim_value
=
{
dim_val
,
dim_val
,
0
};
}
}
else
{
parser
.
default_dyn_dim_value
=
options
.
default_dyn_dim_value
;
}
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
...
...
src/onnx/onnx_parser.cpp
View file @
b878f78f
...
...
@@ -28,16 +28,17 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/common.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/env.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
static
onnx_parser
::
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
...
...
@@ -58,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
std
::
size_t
(
1
),
std
::
multiplies
<
std
::
size_t
>
());
if
(
elem_num
==
0
)
{
return
{
};
return
literal
{
shape_type
};
}
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
...
...
@@ -75,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
std
::
size_t
(
1
),
std
::
multiplies
<
std
::
size_t
>
());
if
(
elem_num
==
0
)
{
return
{
};
return
literal
{
shape_type
};
}
// scalar input
...
...
@@ -255,6 +256,11 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
{
if
(
not
map_input_dims
.
empty
()
and
not
map_dyn_input_dims
.
empty
())
{
MIGRAPHX_THROW
(
"PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used"
);
}
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
for
(
auto
&&
f
:
graph
.
initializer
())
{
...
...
@@ -268,7 +274,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// input not in initializer_data, so it is a real input
if
(
!
contains
(
mod_insts
,
name
))
{
// ONNX specification does not specify h
w
o to deal with the
// ONNX specification does not specify ho
w
to deal with the
// scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
...
...
@@ -278,13 +284,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
"
\"
existing in parent graph!"
);
}
shape
s
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
map_input_dims
.
count
(
name
)
>
0
)
{
dims
=
map_input_dims
.
at
(
name
);
s
=
parse_type
(
input
.
type
(),
dims
);
}
else
if
(
map_dyn_input_dims
.
count
(
name
)
>
0
)
{
shape
::
type_t
shape_type
=
get_type
(
input
.
type
().
tensor_type
().
elem_type
());
s
=
{
shape_type
,
map_dyn_input_dims
.
at
(
name
)};
}
else
{
s
=
parse_type
(
input
.
type
(),
dims
);
}
shape
s
=
parse_type
(
input
.
type
(),
dims
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
}
}
...
...
@@ -439,30 +454,41 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return
{
shape_type
,
input_dims
};
}
std
::
vector
<
s
td
::
size_t
>
dims
;
std
::
vector
<
s
hape
::
dynamic_dimension
>
dynamic_
dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
std
::
transform
(
tensor_dims
.
begin
(),
tensor_dims
.
end
(),
std
::
back_inserter
(
dims
),
[
&
](
auto
&&
d
)
->
s
td
::
size_t
{
std
::
back_inserter
(
dynamic_
dims
),
[
&
](
auto
&&
d
)
->
s
hape
::
dynamic_dimension
{
if
(
d
.
has_dim_value
())
{
if
(
static_cast
<
int
>
(
d
.
dim_value
())
<=
0
)
{
return
default_dim_value
;
return
default_
dyn_
dim_value
;
}
return
d
.
dim_value
();
std
::
size_t
tmp
=
d
.
dim_value
();
return
{
tmp
,
tmp
,
0
};
}
else
{
return
default_dim_value
;
return
default_
dyn_
dim_value
;
}
});
if
(
dims
.
empty
())
if
(
dynamic_dims
.
empty
())
{
return
{
shape_type
};
return
{
shape_type
,
dims
};
}
if
(
std
::
all_of
(
dynamic_dims
.
begin
(),
dynamic_dims
.
end
(),
[](
auto
dd
)
{
return
dd
.
is_fixed
();
}))
{
std
::
vector
<
std
::
size_t
>
dims
;
std
::
transform
(
dynamic_dims
.
begin
(),
dynamic_dims
.
end
(),
std
::
back_inserter
(
dims
),
[](
auto
d
)
{
return
d
.
max
;
});
return
{
shape_type
,
dims
};
}
return
{
shape_type
,
dynamic_dims
};
}
shape
::
type_t
get_type
(
int
dtype
)
...
...
@@ -487,6 +513,16 @@ shape::type_t get_type(int dtype)
}
}
bool
is_type_float
(
shape
::
type_t
dtype
)
{
bool
r
=
false
;
if
(
dtype
==
shape
::
float_type
||
dtype
==
shape
::
double_type
||
dtype
==
shape
::
half_type
)
{
r
=
true
;
}
return
r
;
}
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_constant.cpp
View file @
b878f78f
...
...
@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal
if
(
v
.
get_shape
().
elements
()
==
0
)
{
return
info
.
add_literal
(
literal
{});
return
info
.
add_literal
(
literal
{
v
.
get_shape
().
type
()
});
}
auto
dim_size
=
info
.
attributes
.
at
(
"value"
).
t
().
dims_size
();
...
...
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment