Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1f304aed
Commit
1f304aed
authored
Nov 09, 2018
by
Scott Thornton
Browse files
Fixed faulty add compute_shape when using multibroadcast
parent
6c42bc6e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
21 deletions
+28
-21
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+17
-20
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+4
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+7
-0
No files found.
src/include/migraph/operators.hpp
View file @
1f304aed
...
@@ -762,39 +762,34 @@ struct broadcast
...
@@ -762,39 +762,34 @@ struct broadcast
struct
multibroadcast
struct
multibroadcast
{
{
std
::
vector
<
std
::
size_t
>
output_lens
;
std
::
vector
<
std
::
size_t
>
output_lens
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
output_lens
,
"output_lens"
));
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
size
()
<=
0
)
if
(
input
.
lens
().
size
()
<=
0
)
MIGRAPH_THROW
(
"inputs dimensions should be > 0"
);
MIGRAPH_THROW
(
"inputs dimensions should be > 0"
);
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPH_THROW
(
"inputs dimensions should <= output size"
);
MIGRAPH_THROW
(
"inputs dimensions should <= output size"
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
if
(
input
.
lens
().
size
()
<
output_lens
.
size
())
for
(
int
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
for
(
std
::
size_t
i
=
output_lens
.
size
()
-
1
;
i
>
0
;
i
--
)
{
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
-
offset
])
{
bcast_strides
[
i
]
=
input
.
strides
()[
i
-
offset
];
}
}
}
else
{
for
(
std
::
size_t
i
=
0
;
i
<
input
.
lens
().
size
();
i
++
)
{
{
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
])
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
{
{
bcast_strides
[
i
]
=
input
.
strides
()[
i
];
bcast_strides
[
i
+
offset
]
=
input
.
strides
()[
i
];
}
}
}
}
}
return
{
t
,
output_lens
,
bcast_strides
};
return
{
t
,
output_lens
,
bcast_strides
};
...
@@ -833,7 +828,9 @@ struct binary
...
@@ -833,7 +828,9 @@ struct binary
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
return
inputs
.
at
(
0
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
lens
=
inputs
.
at
(
0
).
lens
();
return
{
t
,
lens
};
}
}
};
};
...
...
src/onnx/onnx.cpp
View file @
1f304aed
...
@@ -107,6 +107,7 @@ struct onnx_parser
...
@@ -107,6 +107,7 @@ struct onnx_parser
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
args
[
0
]
->
get_shape
()},
args
[
1
]);
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
args
[
0
]
->
get_shape
()},
args
[
1
]);
return
prog
.
add_instruction
(
x
,
args
[
0
],
l
);
return
prog
.
add_instruction
(
x
,
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
x
,
args
);
}
}
else
else
{
{
...
@@ -147,8 +148,10 @@ struct onnx_parser
...
@@ -147,8 +148,10 @@ struct onnx_parser
output_lens
[
i
+
offset
]
=
std
::
max
(
s0
[
i
],
s1
[
i
+
offset
]);
output_lens
[
i
+
offset
]
=
std
::
max
(
s0
[
i
],
s1
[
i
+
offset
]);
}
}
}
}
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
0
]);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
1
]);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
}
}
return
prog
.
add_instruction
(
x
,
args
);
});
});
}
}
...
...
test/op_shape_test.cpp
View file @
1f304aed
...
@@ -183,6 +183,13 @@ void multibroadcast_shape()
...
@@ -183,6 +183,13 @@ void multibroadcast_shape()
migraph
::
op
::
multibroadcast
{
lens
},
migraph
::
op
::
multibroadcast
{
lens
},
input
);
input
);
}
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
4
,
1
,
3
};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
1
,
3
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
lens
,
{
0
,
3
,
3
,
1
}},
migraph
::
op
::
multibroadcast
{
lens
},
input
);
}
{
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
1
,
3
};
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
1
,
3
};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
1
,
1
,
1
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
1
,
1
,
1
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment