Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6c42bc6e
Commit
6c42bc6e
authored
Nov 08, 2018
by
Scott Thornton
Browse files
Formatting
parent
8fce4170
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
21 deletions
+19
-21
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+9
-9
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+10
-12
No files found.
src/include/migraph/operators.hpp
View file @
6c42bc6e
...
@@ -769,29 +769,29 @@ struct multibroadcast
...
@@ -769,29 +769,29 @@ struct multibroadcast
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
size
()
<=
0
)
if
(
input
.
lens
().
size
()
<=
0
)
MIGRAPH_THROW
(
"inputs dimensions should be > 0"
);
MIGRAPH_THROW
(
"inputs dimensions should be > 0"
);
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPH_THROW
(
"inputs dimensions should <= output size"
);
MIGRAPH_THROW
(
"inputs dimensions should <= output size"
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
if
(
input
.
lens
().
size
()
<
output_lens
.
size
())
if
(
input
.
lens
().
size
()
<
output_lens
.
size
())
{
{
for
(
std
::
size_t
i
=
output_lens
.
size
()
-
1
;
i
>
0
;
i
--
)
for
(
std
::
size_t
i
=
output_lens
.
size
()
-
1
;
i
>
0
;
i
--
)
{
{
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
-
offset
])
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
-
offset
])
{
{
bcast_strides
[
i
]
=
input
.
strides
()[
i
-
offset
];
bcast_strides
[
i
]
=
input
.
strides
()[
i
-
offset
];
}
}
}
}
}
}
else
else
{
{
for
(
std
::
size_t
i
=
0
;
i
<
input
.
lens
().
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
input
.
lens
().
size
();
i
++
)
{
{
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
])
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
])
{
{
bcast_strides
[
i
]
=
input
.
strides
()[
i
];
bcast_strides
[
i
]
=
input
.
strides
()[
i
];
}
}
...
...
src/onnx/onnx.cpp
View file @
6c42bc6e
...
@@ -93,7 +93,8 @@ struct onnx_parser
...
@@ -93,7 +93,8 @@ struct onnx_parser
void
add_broadcastable_binary_op
(
std
::
string
name
,
T
x
)
void
add_broadcastable_binary_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPH_THROW
(
"binaGry operators should have 2 operands"
);
if
(
args
.
size
()
!=
2
)
MIGRAPH_THROW
(
"binaGry operators should have 2 operands"
);
if
(
contains
(
attributes
,
"broadcast"
))
if
(
contains
(
attributes
,
"broadcast"
))
{
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
...
@@ -127,33 +128,30 @@ struct onnx_parser
...
@@ -127,33 +128,30 @@ struct onnx_parser
const
std
::
vector
<
std
::
size_t
>&
s1
=
args
[
1
]
->
get_shape
().
lens
();
const
std
::
vector
<
std
::
size_t
>&
s1
=
args
[
1
]
->
get_shape
().
lens
();
// Copy the larger vector to output_lens
// Copy the larger vector to output_lens
std
::
vector
<
std
::
size_t
>
output_lens
=
std
::
vector
<
std
::
size_t
>
output_lens
=
(
s0
.
size
()
>=
s1
.
size
())
?
s0
:
s1
;
(
s0
.
size
()
>=
s1
.
size
())
?
s0
:
s1
;
if
(
s0
.
size
()
>=
s1
.
size
())
if
(
s0
.
size
()
>=
s1
.
size
())
{
{
// s0 is bigger, so iterate over the range of s1
// s0 is bigger, so iterate over the range of s1
auto
offset
=
s0
.
size
()
-
s1
.
size
();
auto
offset
=
s0
.
size
()
-
s1
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
s1
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
s1
.
size
();
i
++
)
{
{
output_lens
[
i
+
offset
]
=
std
::
max
(
s0
[
i
+
offset
],
s1
[
i
]);
output_lens
[
i
+
offset
]
=
std
::
max
(
s0
[
i
+
offset
],
s1
[
i
]);
}
}
}
}
else
else
{
{
// s1 is bigger, so iterate over the range of s0
// s1 is bigger, so iterate over the range of s0
auto
offset
=
s1
.
size
()
-
s0
.
size
();
auto
offset
=
s1
.
size
()
-
s0
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
s0
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
s0
.
size
();
i
++
)
{
{
output_lens
[
i
+
offset
]
=
std
::
max
(
s0
[
i
],
s1
[
i
+
offset
]);
output_lens
[
i
+
offset
]
=
std
::
max
(
s0
[
i
],
s1
[
i
+
offset
]);
}
}
}
}
}
}
return
prog
.
add_instruction
(
x
,
args
);
return
prog
.
add_instruction
(
x
,
args
);
});
});
}
}
template
<
class
T
>
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment