Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0f85317e
Commit
0f85317e
authored
Jul 11, 2019
by
Shucai Xiao
Browse files
changes to make bert model work
parent
a3aacad6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
9 deletions
+29
-9
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+1
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+0
-8
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+28
-1
No files found.
src/include/migraphx/op/dot.hpp
View file @
0f85317e
...
...
@@ -48,6 +48,7 @@ struct dot
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
// dims for batch should be standard
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
...
...
src/onnx/onnx.cpp
View file @
0f85317e
...
...
@@ -691,14 +691,6 @@ struct onnx_parser
}
}
if
(
!
bl1
->
get_shape
().
standard
())
{
bl1
=
prog
.
add_instruction
(
op
::
contiguous
{},
bl1
);
}
if
(
!
bl0
->
get_shape
().
standard
())
{
bl0
=
prog
.
add_instruction
(
op
::
contiguous
{},
bl0
);
}
auto
dot_res
=
prog
.
add_instruction
(
op
::
dot
{
1.0
f
,
0.0
f
},
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
if
(
is_a_prepended
)
...
...
src/targets/gpu/gemm.cpp
View file @
0f85317e
...
...
@@ -170,7 +170,34 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
shape
miopen_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
std
::
vector
<
shape
>
input_shapes
(
inputs
.
begin
(),
inputs
.
begin
()
+
inputs
.
size
()
-
1
);
check_shapes
{
input_shapes
}.
standard
();
check_shapes
{
input_shapes
}.
not_broadcasted
();
auto
a_strides
=
inputs
[
0
].
strides
();
auto
dim_0
=
a_strides
.
size
()
-
2
;
if
(
a_strides
.
size
()
>
2
)
{
if
(
!
std
::
all_of
(
a_strides
.
begin
(),
a_strides
.
begin
()
+
dim_0
,
[
&
](
auto
batch_size
)
{
return
std
::
all_of
(
a_strides
.
begin
()
+
dim_0
,
a_strides
.
end
(),
[
&
](
auto
data_size
)
{
return
batch_size
>=
data_size
;
});
}))
{
MIGRAPHX_THROW
(
"DOT: batch size of a {"
+
to_string_range
(
a_strides
)
+
"} is transposed!"
);
}
}
auto
b_strides
=
inputs
[
1
].
strides
();
if
(
b_strides
.
size
()
>
2
)
{
if
(
!
std
::
all_of
(
b_strides
.
begin
(),
b_strides
.
begin
()
+
dim_0
,
[
&
](
auto
batch_size
)
{
return
std
::
all_of
(
b_strides
.
begin
()
+
dim_0
,
b_strides
.
end
(),
[
&
](
auto
data_size
)
{
return
batch_size
>=
data_size
;
});
}))
{
MIGRAPHX_THROW
(
"DOT: batch size of b {"
+
to_string_range
(
b_strides
)
+
"} is transposed!"
);
}
}
return
op
.
compute_shape
(
input_shapes
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment