Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
be99b85d
Commit
be99b85d
authored
Jul 23, 2019
by
Khalique
Browse files
Merge branch 'test_bert' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into bert_ops
parents
8819f4dc
b0e589fd
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
51 additions
and
4 deletions
+51
-4
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+16
-3
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+1
-1
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+18
-0
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+1
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+15
-0
test/onnx/reshape_non_standard.onnx
test/onnx/reshape_non_standard.onnx
+0
-0
No files found.
src/onnx/onnx.cpp
View file @
be99b85d
...
...
@@ -460,16 +460,30 @@ struct onnx_parser
{
op
::
reshape
op
;
if
(
args
.
size
()
==
1
)
{
if
(
contains
(
attributes
,
"shape"
))
{
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
}
else
{
MIGRAPHX_THROW
(
"Parse_reshape: shape attribute is needed when only one argument is provided!"
);
}
}
if
(
args
.
size
()
==
2
)
{
auto
s
=
args
[
1
]
->
eval
();
check_arg_empty
(
s
,
"Reshape: dynamic shape is not supported"
);
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
}
if
(
!
args
[
0
]
->
get_shape
().
standard
())
{
args
[
0
]
=
prog
.
add_instruction
(
op
::
contiguous
{},
args
[
0
]);
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
}
...
...
@@ -972,7 +986,6 @@ struct onnx_parser
std
::
vector
<
std
::
size_t
>
dims
;
arg_s
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
auto
out_lens
=
compute_broadcasted_lens
(
in_lens
,
dims
);
return
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
0
]);
}
...
...
src/targets/gpu/device/gather.cpp
View file @
be99b85d
...
...
@@ -25,7 +25,7 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
arg2
.
visit
([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
gs_launch
(
stream
,
nelements
,
256
)([
=
](
auto
i
)
{
auto
idx
=
out_comp
.
multi
(
i
);
idx
[
axis_index
]
=
indices_ptr
[
idx
[
axis_index
]];
output_ptr
[
i
]
=
input
[
idx
];
...
...
src/targets/gpu/gemm.cpp
View file @
be99b85d
...
...
@@ -167,10 +167,28 @@ rb_type<T>* to_rocblas_type(T* x)
rocblas_half
to_rocblas_type
(
half
x
)
{
return
reinterpret_cast
<
const
rocblas_half
&>
(
x
);
}
void
miopen_gemm
::
batch_not_transposed
(
const
std
::
vector
<
std
::
size_t
>&
strides
)
const
{
if
(
strides
.
size
()
<=
2
)
return
;
auto
dim_0
=
strides
.
size
()
-
2
;
auto
matrix_size
=
std
::
max
(
strides
[
dim_0
],
strides
[
dim_0
+
1
]);
std
::
vector
<
std
::
size_t
>
batch
(
strides
.
begin
(),
strides
.
begin
()
+
dim_0
);
if
(
std
::
adjacent_find
(
batch
.
begin
(),
batch
.
end
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
i
<
j
or
i
<
matrix_size
or
j
<
matrix_size
);
})
!=
batch
.
end
())
{
MIGRAPHX_THROW
(
"DOT: batch size of a {"
+
to_string_range
(
strides
)
+
"} is transposed!"
);
}
}
shape
miopen_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
std
::
vector
<
shape
>
input_shapes
(
inputs
.
begin
(),
inputs
.
begin
()
+
inputs
.
size
()
-
1
);
check_shapes
{
input_shapes
}.
not_broadcasted
();
batch_not_transposed
(
inputs
[
0
].
strides
());
batch_not_transposed
(
inputs
[
1
].
strides
());
return
op
.
compute_shape
(
input_shapes
);
}
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
be99b85d
...
...
@@ -24,6 +24,7 @@ struct miopen_gemm
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
void
batch_not_transposed
(
const
std
::
vector
<
std
::
size_t
>&
strides
)
const
;
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
...
...
test/onnx/onnx_test.cpp
View file @
be99b85d
...
...
@@ -447,6 +447,21 @@ TEST_CASE(reshape_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
reshape_non_standard
)
{
migraphx
::
program
p
;
migraphx
::
op
::
reshape
op
;
std
::
vector
<
int64_t
>
reshape_dims
{
4
,
3
,
2
};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
tran_x
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
1
}},
x
);
auto
cont_x
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
tran_x
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
4
,
3
,
2
}},
cont_x
);
auto
prog
=
migraphx
::
parse_onnx
(
"reshape_non_standard.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
shape_test
)
{
migraphx
::
program
p
;
...
...
test/onnx/reshape_non_standard.onnx
0 → 100644
View file @
be99b85d
File added
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment