gaoqiong / MIGraphX / Commits

Commit 9d52515a, authored Mar 08, 2019 by Shucai Xiao

    code backup for an initial implementation of the function compute_shape.

parent c45be227
Showing 2 changed files with 147 additions and 43 deletions.
src/include/migraphx/operators.hpp   +127 / -1
src/targets/gpu/gemm.cpp             +20  / -42
src/include/migraphx/operators.hpp
...
@@ -810,7 +810,8 @@ struct gather
 // The dot operation is a combination of the onnx GEMM and MatMul operators.
 // For GEMM, it supports the C matrix in the formula alpha * AB + beta * C,
 // in which C is broadcastable to the shape of AB. For the transpose of A
-// and B, we add a transpose operator if the onnx file needs it.
+// and B, we add a transpose operator beforehand if the onnx gemm operator
+// indicates a transpose.
 // For MatMul, it has the same definition as numpy.matmul, which means
 // A and B can have 1 to N dims. For a 1-dim A, it is a vector * matrix;
 // for a 1-dim B, it is a matrix * vector. Note that there is no support
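As a quick illustration of the GEMM half described in this comment, the operation computes D = alpha * A * B + beta * C, with C broadcast to the shape of A * B. A minimal standalone reference of that formula (not part of this commit; the function name, the flat row-major buffers, and the scalar-or-full-matrix C are illustrative assumptions):

#include <cstddef>
#include <vector>

// Naive reference for D = alpha * A * B + beta * C, with A (m x k), B (k x n)
// and C either a single value broadcast everywhere or a full m x n matrix.
std::vector<float> gemm_reference(const std::vector<float>& a,
                                  const std::vector<float>& b,
                                  const std::vector<float>& c,
                                  std::size_t m,
                                  std::size_t n,
                                  std::size_t k,
                                  float alpha,
                                  float beta)
{
    std::vector<float> d(m * n, 0.0f);
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t j = 0; j < n; ++j)
        {
            float acc = 0.0f;
            for(std::size_t l = 0; l < k; ++l)
                acc += a[i * k + l] * b[l * n + j];
            float c_ij   = (c.size() == 1) ? c[0] : c[i * n + j];
            d[i * n + j] = alpha * acc + beta * c_ij;
        }
    return d;
}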
...
@@ -829,6 +830,54 @@ struct dot
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
    }

    // Broadcast the leading (batch) dimensions of the two operands, following the
    // numpy broadcasting rule: compared from the innermost dimension outward, the
    // sizes must either match or one of them must be 1.
    std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a,
                                             std::vector<std::size_t>& b) const
    {
        if(a.empty())
            return b;
        else if(b.empty())
            return a;

        auto a_size = a.size();
        auto b_size = b.size();
        auto n_dim  = std::min(a_size, b_size);
        // out_lens is filled innermost-first and reversed at the end
        std::vector<std::size_t> out_lens(std::max(a_size, b_size));
        for(std::size_t i = 0; i < n_dim; ++i)
        {
            if(a[a_size - 1 - i] == b[b_size - 1 - i])
            {
                out_lens[i] = a[a_size - 1 - i];
            }
            else if(a[a_size - 1 - i] == 1)
            {
                out_lens[i] = b[b_size - 1 - i];
            }
            else if(b[b_size - 1 - i] == 1)
            {
                out_lens[i] = a[a_size - 1 - i];
            }
            else
            {
                MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
                               "}, and matrix B: {" + to_string_range(b) +
                               "} are not broadcastable");
            }
        }

        // the extra leading dimensions of the longer operand are copied unchanged
        if(a_size > n_dim)
        {
            std::copy(a.rbegin() + n_dim, a.rend(), out_lens.begin() + n_dim);
        }

        if(b_size > n_dim)
        {
            std::copy(b.rbegin() + n_dim, b.rend(), out_lens.begin() + n_dim);
        }

        std::reverse(out_lens.begin(), out_lens.end());
        return out_lens;
    }

    std::string name() const { return "dot"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
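For a quick sanity check of the broadcasting rule that the new shape_broadcast helper above implements, here is a compact standalone version written front-to-back, with the results one would expect for a few batch-dimension pairs (illustrative sketch only; broadcast_dims is a hypothetical name, not the committed member function):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Same rule as shape_broadcast, but walking the dimensions front-to-back.
std::vector<std::size_t> broadcast_dims(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    if(a.size() < b.size())
        std::swap(a, b); // make `a` the longer vector
    std::vector<std::size_t> out(a);
    std::size_t offset = a.size() - b.size();
    for(std::size_t i = 0; i < b.size(); ++i)
    {
        std::size_t da = a[offset + i];
        std::size_t db = b[i];
        if(da == db || da == 1 || db == 1)
            out[offset + i] = std::max(da, db);
        else
            throw std::runtime_error("not broadcastable");
    }
    return out;
}

int main()
{
    assert((broadcast_dims({2, 1, 4}, {3, 4}) == std::vector<std::size_t>{2, 3, 4}));
    assert((broadcast_dims({5, 1}, {1, 6}) == std::vector<std::size_t>{5, 6}));
    assert((broadcast_dims({}, {7, 8}) == std::vector<std::size_t>{7, 8}));
    // broadcast_dims({3}, {4}) would throw: 3 and 4 are not broadcastable
}

The committed helper walks the dimensions innermost-first and reverses at the end; this sketch walks them front-to-back, but the accepted and rejected cases are the same.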
...
@@ -837,6 +886,83 @@ struct dot
        const shape& b = inputs.at(1);
        auto t         = a.type();

        if(a.scalar() || b.scalar())
        {
            MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead");
        }

        auto a_lens = a.lens();
        auto b_lens = b.lens();
        std::vector<std::size_t> out_lens;
        if(a_lens.size() == 1)
        {
            // inner product, output is a scalar, following numpy.matmul()
            if(b_lens.size() == 1)
            {
                if(a_lens.front() != b_lens.front())
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
                                   to_string_range(b_lens) + "}");
                }
            }
            // vector * matrix: A must match the second-to-last dimension of B,
            // which is then dropped from the output shape
            else
            {
                std::size_t dim_0 = b_lens.size() - 2;
                if(a_lens.front() != b_lens[dim_0])
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
                                   to_string_range(b_lens) + "}");
                }
                out_lens = b_lens;
                out_lens.erase(out_lens.begin() + dim_0);
            }
        }
        else
        {
            std::size_t dim_0 = a_lens.size() - 1;
            // matrix * vector: B must match the last dimension of A,
            // which is then dropped from the output shape
            if(b_lens.size() == 1)
            {
                if(a_lens.back() != b_lens.back())
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
                                   to_string_range(b_lens) + "}");
                }
                out_lens = a_lens;
                out_lens.pop_back();
            }
            // matrix * matrix: the last dim of A must equal the second-to-last dim
            // of B; the remaining leading dims are broadcast as batch dims
            else
            {
                std::size_t dim_1 = b_lens.size() - 2;
                if(a_lens[dim_0] != b_lens[dim_1])
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
                                   to_string_range(b_lens) + "}");
                }

                a_lens.pop_back();
                std::size_t out_m = a_lens.back();
                a_lens.pop_back();
                std::size_t out_n = b_lens.back();
                b_lens.pop_back();
                b_lens.pop_back();
                out_lens = shape_broadcast(a_lens, b_lens);
                out_lens.push_back(out_m);
                out_lens.push_back(out_n);
            }
        }

        // c is broadcast
        if(inputs.size() == 3)
            // according to the specification of numpy.matmul(), inputs with more
            // than 2 shape dims are acceptable as long as the dim values are the
            // same in the two inputs
...
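Put together, the new compute_shape follows the numpy.matmul conventions described in the comments. For illustration (hand-derived from the rules above, not program output):

- A {3} dot B {3, 5} gives {5}: the 1-D A acts as a 1 x 3 row vector and the prepended dim is dropped.
- A {2, 3} dot B {3} gives {2}: the 1-D B acts as a 3 x 1 column vector and the appended dim is dropped.
- A {6, 2, 3} dot B {1, 3, 5} gives {6, 2, 5}: the batch dims {6} and {1} broadcast to {6}, then M = 2 and N = 5 are appended.
- Two 1-D operands of equal length form an inner product.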
src/targets/gpu/gemm.cpp
...
@@ -127,9 +127,6 @@ argument miopen_gemm::compute(context& ctx,
        auto alpha_r    = to_rocblas_type(as(op.alpha));
        auto beta_r     = to_rocblas_type(as(op.beta));
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

        // call the strided implementation only if there are multiple matrices
        if(batch_num > 1)
        {
            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
...
@@ -150,25 +147,6 @@ argument miopen_gemm::compute(context& ctx,
                                         ldc,
                                         m * n,
                                         batch_num);
        }
        else
        {
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 ldb,
                                 to_pointer(args[0]),
                                 lda,
                                 &beta_r,
                                 to_pointer(args[2]),
                                 ldc);
        }
    });

    return (is_3inputs ? args[3] : args[2]);
...
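One thing worth noting about the non-batched call: the operands and dimensions are passed swapped (B before A, n before m, transb before transa). This looks like the usual trick for driving a column-major BLAS from row-major buffers: a row-major m x n matrix C occupies the same memory as a column-major n x m matrix C^T, and C^T = B^T * A^T. A standalone check of that identity (illustrative only, not rocBLAS or MIGraphX code):

#include <cassert>
#include <cstddef>
#include <vector>

// Reinterpreting row-major C (m x n) as column-major data yields C^T (n x m),
// and C^T = B^T * A^T, so a column-major GEMM called with B and A swapped
// (and n/m swapped) fills exactly the same buffer as a row-major C = A * B.
int main()
{
    std::size_t m = 2, n = 3, k = 4;
    std::vector<double> a(m * k), b(k * n);
    for(std::size_t i = 0; i < a.size(); ++i) a[i] = i + 1;
    for(std::size_t i = 0; i < b.size(); ++i) b[i] = 2 * i + 1;

    // row-major C = A * B
    std::vector<double> c_row(m * n, 0.0);
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t j = 0; j < n; ++j)
            for(std::size_t l = 0; l < k; ++l)
                c_row[i * n + j] += a[i * k + l] * b[l * n + j];

    // column-major C' = B^T * A^T, where the row-major buffers are read as
    // column-major B^T (n x k) and A^T (k x m); C' is column-major n x m
    std::vector<double> c_col(m * n, 0.0);
    for(std::size_t j = 0; j < m; ++j)     // columns of C' (= rows of C)
        for(std::size_t i = 0; i < n; ++i) // rows of C'    (= cols of C)
            for(std::size_t l = 0; l < k; ++l)
                c_col[j * n + i] += b[l * n + i] * a[j * k + l];

    assert(c_row == c_col); // identical buffer contents, no transpose copies needed
}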