Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
45e82cad
Commit
45e82cad
authored
Apr 03, 2019
by
Shucai Xiao
Browse files
clang format
parent
154b9287
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
42 deletions
+41
-42
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-1
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+39
-41
No files found.
src/include/migraphx/operators.hpp
View file @
45e82cad
...
...
@@ -857,7 +857,8 @@ struct dot
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
inputs
.
at
(
2
).
lens
())
+
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
inputs
.
at
(
2
).
lens
())
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
...
...
src/targets/gpu/gemm.cpp
View file @
45e82cad
...
...
@@ -274,60 +274,58 @@ argument miopen_gemm::compute(context& ctx,
// matrix * vector (b is a vector)
else
if
(
b_lens
.
size
()
==
2
&&
b_lens
.
at
(
1
)
==
1
)
{
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
rocblas_int
m
=
static_cast
<
rocblas_int
>
(
a_lens
[
0
]);
rocblas_int
n
=
static_cast
<
rocblas_int
>
(
a_lens
[
1
]);
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
1
:
0
];
float
beta
=
0.0
f
;
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
rocblas_int
m
=
static_cast
<
rocblas_int
>
(
a_lens
[
0
]);
rocblas_int
n
=
static_cast
<
rocblas_int
>
(
a_lens
[
1
]);
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
1
:
0
];
float
beta
=
0.0
f
;
assert
(
a_lens
.
back
()
==
args
[
1
].
get_shape
().
elements
());
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_gemv
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
n
,
&
alpha_r
,
to_pointer
(
args
[
0
]),
lda
,
to_pointer
(
args
[
1
]),
1
,
&
beta_r
,
to_pointer
(
args
[
2
]),
1
);
generic_rocblas_gemv
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
n
,
&
alpha_r
,
to_pointer
(
args
[
0
]),
lda
,
to_pointer
(
args
[
1
]),
1
,
&
beta_r
,
to_pointer
(
args
[
2
]),
1
);
});
}
// vector * matrix (a is a vector)
else
if
(
a_lens
.
size
()
==
2
&&
a_lens
.
at
(
0
)
==
1
)
{
bool
transb
=
!
args
[
1
].
get_shape
().
transposed
();
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[(
transb
?
1
:
0
)];
rocblas_int
m
=
b_lens
[
0
];
rocblas_int
n
=
b_lens
[
1
];
float
beta
=
0.0
f
;
bool
transb
=
!
args
[
1
].
get_shape
().
transposed
();
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[(
transb
?
1
:
0
)];
rocblas_int
m
=
b_lens
[
0
];
rocblas_int
n
=
b_lens
[
1
];
float
beta
=
0.0
f
;
assert
(
b_lens
[
0
]
==
args
[
0
].
get_shape
().
elements
());
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_gemv
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
n
,
&
alpha_r
,
to_pointer
(
args
[
1
]),
ldb
,
to_pointer
(
args
[
0
]),
1
,
&
beta_r
,
to_pointer
(
args
[
2
]),
1
);
generic_rocblas_gemv
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
n
,
&
alpha_r
,
to_pointer
(
args
[
1
]),
ldb
,
to_pointer
(
args
[
0
]),
1
,
&
beta_r
,
to_pointer
(
args
[
2
]),
1
);
});
}
// batch matrix multiplication
...
...
@@ -337,7 +335,7 @@ argument miopen_gemm::compute(context& ctx,
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
...
...
@@ -374,9 +372,9 @@ argument miopen_gemm::compute(context& ctx,
ldc
,
m
*
n
,
num_matrices
);
});
});
}
return
args
[
2
];
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment