Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
360db15f
Commit
360db15f
authored
Mar 14, 2019
by
Shucai Xiao
Browse files
clang format
parent
f9c38c09
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
62 deletions
+74
-62
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+26
-22
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+40
-36
test/op_shape_test.cpp
test/op_shape_test.cpp
+8
-4
No files found.
src/include/migraphx/operators.hpp
View file @
360db15f
...
...
@@ -809,8 +809,8 @@ struct gather
// The dot operation is combination of the onnx GEMM and MatMul operators.
// For GEMM, it support two cases: 1) in the formula alpha * AB + beta * C,
// A and B are 2-D matrics and C is broadcastable to the shape of A*B. For
// the transpose of A and B, we add a tranpose operator beforehand if the
// A and B are 2-D matrics and C is broadcastable to the shape of A*B. For
// the transpose of A and B, we add a tranpose operator beforehand if the
// onnx gemm operator indicates a transpose required. 2) A and B are more
// than 2-D, then the dims except the last 2-D in A and B need to be the
// same, and C should be the same shape as A * B
...
...
@@ -898,36 +898,39 @@ struct dot
// If there are 3 inputs, there are two scenarios:
// 1. A and B are 2-D matrices and C is broadcastable to A * B
// 2. A and B are stack of matrices, then shape for the batch
// should be the same for A and B, and C is the same shape
// as A * B (For now, we add this requirement to simplify the
// should be the same for A and B, and C is the same shape
// as A * B (For now, we add this requirement to simplify the
// implementation. we can remove this requirement later)
if
(
inputs
.
size
()
==
3
)
{
auto
a_lens
=
inputs
[
0
].
lens
();
auto
b_lens
=
inputs
[
1
].
lens
();
auto
a_lens
=
inputs
[
0
].
lens
();
auto
b_lens
=
inputs
[
1
].
lens
();
auto
out_lens
=
a_lens
;
auto
t
=
inputs
[
0
].
type
();
if
(
inputs
[
1
].
lens
().
size
()
>
2
)
auto
t
=
inputs
[
0
].
type
();
if
(
inputs
[
1
].
lens
().
size
()
>
2
)
{
if
(
!
std
::
equal
(
a_lens
.
rbegin
()
+
2
,
a_lens
.
rend
(),
b_lens
.
rbegin
()
+
2
))
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
}
std
::
size_t
dim_0
=
a_lens
.
size
()
-
2
;
std
::
size_t
dim_1
=
a_lens
.
size
()
-
1
;
if
(
a_lens
[
dim_1
]
!=
b_lens
[
dim_0
])
MIGRAPHX_THROW
(
"Inner dimensions do not match, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
MIGRAPHX_THROW
(
"Inner dimensions do not match, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
out_lens
[
dim_1
]
=
b_lens
[
dim_1
];
// C should be the same shape as A * B
auto
c_lens
=
inputs
[
2
].
lens
();
if
(
!
std
::
equal
(
c_lens
.
begin
(),
c_lens
.
end
(),
out_lens
.
begin
()))
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
c_lens
)
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
c_lens
)
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
}
else
...
...
@@ -938,22 +941,23 @@ struct dot
if
(
a_lens
[
1
]
!=
b_lens
[
0
])
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
}
out_lens
[
1
]
=
b_lens
[
1
];
out_lens
[
1
]
=
b_lens
[
1
];
// check whether C is broadcastable to A * B
auto
c_lens
=
inputs
[
2
].
lens
();
if
(
c_lens
.
size
()
>
2
||
(
c_lens
.
size
()
==
1
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
b_lens
[
1
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
a_lens
[
0
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
!=
1
&&
c_lens
[
1
]
!=
b_lens
[
1
])))
(
c_lens
.
size
()
==
1
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
b_lens
[
1
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
a_lens
[
0
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
!=
1
&&
c_lens
[
1
]
!=
b_lens
[
1
])))
{
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
c_lens
)
+
"} is not broadcastable to A * B {"
+
to_string_range
(
out_lens
)
+
"}"
);
"} is not broadcastable to A * B {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
}
...
...
src/targets/gpu/gemm.cpp
View file @
360db15f
...
...
@@ -369,43 +369,47 @@ argument miopen_gemm::compute(context& ctx,
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
is_3inputs
)
{
fill_result
(
output_shape
,
args
[
3
],
args
[
2
]);
fill_result
(
output_shape
,
args
[
3
],
args
[
2
]);
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
op
.
beta
));
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
3
].
get_shape
().
strides
()[
dim_0
];
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
[
1
]),
ldb
,
k
*
n
,
to_pointer
(
args
[
0
]),
lda
,
m
*
k
,
&
beta_r
,
to_pointer
(
args
[
3
]),
ldc
,
m
*
n
,
num_matrices
);
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
op
.
beta
));
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
3
].
get_shape
().
strides
()[
dim_0
];
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
[
1
]),
ldb
,
k
*
n
,
to_pointer
(
args
[
0
]),
lda
,
m
*
k
,
&
beta_r
,
to_pointer
(
args
[
3
]),
ldc
,
m
*
n
,
num_matrices
);
});
...
...
test/op_shape_test.cpp
View file @
360db15f
...
...
@@ -423,15 +423,19 @@ TEST_CASE(dot)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment