Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f9c38c09
Commit
f9c38c09
authored
Mar 14, 2019
by
Shucai Xiao
Browse files
support the case of multiple matrices as inputs with C, but C should be the same shape as A * B
parent
ca28e1e8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
91 additions
and
39 deletions
+91
-39
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+58
-26
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+17
-11
test/op_shape_test.cpp
test/op_shape_test.cpp
+16
-2
No files found.
src/include/migraphx/operators.hpp
View file @
f9c38c09
...
...
@@ -808,10 +808,12 @@ struct gather
};
// The dot operation is combination of the onnx GEMM and MatMul operators.
// For GEMM, it support the C matrix in the formula alpha * AB + beta * C,
// in which C is broadcastable to the shape of AB. For the transpose of A
// and B, we add a tranpose operator beforehand if the onnx gemm operator
// indicates a transpose.
// For GEMM, it supports two cases: 1) in the formula alpha * AB + beta * C,
// A and B are 2-D matrices and C is broadcastable to the shape of A*B. For
// the transpose of A and B, we add a transpose operator beforehand if the
// onnx gemm operator indicates a transpose is required. 2) A and B are more
// than 2-D; then the dims except the last two in A and B need to be the
// same, and C should be the same shape as A * B
// For MatMul, it has the same definition as the numpy.matmul, which means
// A, B could be 1 to N-dims. For 1-dim input of A, it is a vector * matrix,
// for 1-dim of B, it is a matrix * vector. Note that there is not support
...
...
@@ -893,36 +895,66 @@ struct dot
std
::
string
name
()
const
{
return
"dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
// If there are 3 inputs, then A and B must be matrices and
// C should be broadcastable to A * B
// If there are 3 inputs, there are two scenarios:
// 1. A and B are 2-D matrices and C is broadcastable to A * B
// 2. A and B are each a stack of matrices; the batch dimensions
// must be the same for A and B, and C must have the same shape
// as A * B. (For now, we add this requirement to simplify the
// implementation. We can remove this requirement later.)
if
(
inputs
.
size
()
==
3
)
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
same_type
();
check_shapes
{{
inputs
[
0
]},
*
this
}.
only_dims
(
2
);
check_shapes
{{
inputs
[
1
]},
*
this
}.
only_dims
(
2
);
auto
a_lens
=
inputs
[
0
].
lens
();
auto
b_lens
=
inputs
[
1
].
lens
();
auto
out_lens
=
a_lens
;
auto
t
=
inputs
[
0
].
type
();
if
(
a_lens
[
1
]
!=
b_lens
[
0
]
)
if
(
inputs
[
1
].
lens
().
size
()
>
2
)
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
}
if
(
!
std
::
equal
(
a_lens
.
rbegin
()
+
2
,
a_lens
.
rend
(),
b_lens
.
rbegin
()
+
2
))
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
}
auto
out_lens
=
a_lens
;
out_lens
[
1
]
=
b_lens
[
1
];
// check whether C is broadcastable to A * B
auto
c_lens
=
inputs
[
2
].
lens
();
if
(
c_lens
.
size
()
>
2
||
(
c_lens
.
size
()
==
1
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
b_lens
[
1
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
a_lens
[
0
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
!=
1
&&
c_lens
[
1
]
!=
b_lens
[
1
])))
std
::
size_t
dim_0
=
a_lens
.
size
()
-
2
;
std
::
size_t
dim_1
=
a_lens
.
size
()
-
1
;
if
(
a_lens
[
dim_1
]
!=
b_lens
[
dim_0
])
MIGRAPHX_THROW
(
"Inner dimensions do not match, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
out_lens
[
dim_1
]
=
b_lens
[
dim_1
];
// C should be the same shape as A * B
auto
c_lens
=
inputs
[
2
].
lens
();
if
(
!
std
::
equal
(
c_lens
.
begin
(),
c_lens
.
end
(),
out_lens
.
begin
()))
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
c_lens
)
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
}
else
{
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
c_lens
)
+
"} is not broadcastable to A * B {"
+
to_string_range
(
out_lens
)
+
"}"
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
same_type
();
check_shapes
{{
inputs
[
0
]},
*
this
}.
only_dims
(
2
);
check_shapes
{{
inputs
[
1
]},
*
this
}.
only_dims
(
2
);
if
(
a_lens
[
1
]
!=
b_lens
[
0
])
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
}
out_lens
[
1
]
=
b_lens
[
1
];
// check whether C is broadcastable to A * B
auto
c_lens
=
inputs
[
2
].
lens
();
if
(
c_lens
.
size
()
>
2
||
(
c_lens
.
size
()
==
1
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
b_lens
[
1
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
a_lens
[
0
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
!=
1
&&
c_lens
[
1
]
!=
b_lens
[
1
])))
{
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
c_lens
)
+
"} is not broadcastable to A * B {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
}
return
{
t
,
out_lens
};
...
...
src/targets/gpu/gemm.cpp
View file @
f9c38c09
...
...
@@ -369,23 +369,25 @@ argument miopen_gemm::compute(context& ctx,
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
is_3inputs
)
{
fill_result
(
output_shape
,
args
[
3
],
args
[
2
]);
fill_result
(
output_shape
,
args
[
3
],
args
[
2
]);
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
op
.
beta
));
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
1
:
0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
1
:
0
];
rocblas_int
ldc
=
args
[
3
].
get_shape
().
strides
()[
0
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_
0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_
0
];
rocblas_int
ldc
=
args
[
3
].
get_shape
().
strides
()[
dim_
0
];
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
0
];
rocblas_int
n
=
out_lens
[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_gemm
(
as
,
generic_rocblas_batched_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
...
@@ -395,11 +397,15 @@ argument miopen_gemm::compute(context& ctx,
&
alpha_r
,
to_pointer
(
args
[
1
]),
ldb
,
k
*
n
,
to_pointer
(
args
[
0
]),
lda
,
m
*
k
,
&
beta_r
,
to_pointer
(
args
[
3
]),
ldc
);
ldc
,
m
*
n
,
num_matrices
);
});
...
...
test/op_shape_test.cpp
View file @
f9c38c09
...
...
@@ -420,6 +420,20 @@ TEST_CASE(dot)
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
5
,
7
}};
...
...
@@ -433,14 +447,14 @@ TEST_CASE(dot)
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
5
,
7
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
5
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment