Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ca28e1e8
Commit
ca28e1e8
authored
Mar 14, 2019
by
Shucai Xiao
Browse files
clang format
parent
02f359b2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
43 deletions
+45
-43
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+43
-42
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+2
-1
No files found.
src/targets/gpu/gemm.cpp
View file @
ca28e1e8
...
@@ -265,12 +265,16 @@ argument miopen_gemm::batch_matmul(context& ctx,
...
@@ -265,12 +265,16 @@ argument miopen_gemm::batch_matmul(context& ctx,
std
::
vector
<
std
::
size_t
>
a_batch_lens
(
a_lens
.
begin
(),
a_lens
.
begin
()
+
an_dim
-
2
);
std
::
vector
<
std
::
size_t
>
a_batch_lens
(
a_lens
.
begin
(),
a_lens
.
begin
()
+
an_dim
-
2
);
std
::
vector
<
std
::
size_t
>
b_batch_lens
(
b_lens
.
begin
(),
b_lens
.
begin
()
+
bn_dim
-
2
);
std
::
vector
<
std
::
size_t
>
b_batch_lens
(
b_lens
.
begin
(),
b_lens
.
begin
()
+
bn_dim
-
2
);
if
(
a_batch_lens
==
b_batch_lens
||
a_batch_lens
.
empty
()
||
b_batch_lens
.
empty
())
if
(
a_batch_lens
==
b_batch_lens
||
a_batch_lens
.
empty
()
||
b_batch_lens
.
empty
())
{
{
std
::
size_t
numa_matrices
=
std
::
size_t
numa_matrices
=
std
::
accumulate
(
a_batch_lens
.
begin
(),
std
::
accumulate
(
a_batch_lens
.
begin
(),
a_batch_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
a_batch_lens
.
end
(),
std
::
size_t
numb_matrices
=
std
::
size_t
{
1
},
std
::
accumulate
(
b_batch_lens
.
begin
(),
b_batch_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
numb_matrices
=
std
::
accumulate
(
b_batch_lens
.
begin
(),
b_batch_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_matrices
=
std
::
max
(
numa_matrices
,
numb_matrices
);
std
::
size_t
num_matrices
=
std
::
max
(
numa_matrices
,
numb_matrices
);
rocblas_int
stride_a
=
(
numa_matrices
==
1
)
?
0
:
m
*
k
;
rocblas_int
stride_a
=
(
numa_matrices
==
1
)
?
0
:
m
*
k
;
rocblas_int
stride_b
=
(
numb_matrices
==
1
)
?
0
:
k
*
n
;
rocblas_int
stride_b
=
(
numb_matrices
==
1
)
?
0
:
k
*
n
;
...
@@ -278,7 +282,7 @@ argument miopen_gemm::batch_matmul(context& ctx,
...
@@ -278,7 +282,7 @@ argument miopen_gemm::batch_matmul(context& ctx,
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
generic_rocblas_batched_gemm
(
as
,
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
...
@@ -336,8 +340,7 @@ argument miopen_gemm::batch_matmul(context& ctx,
...
@@ -336,8 +340,7 @@ argument miopen_gemm::batch_matmul(context& ctx,
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
};
generic_rocblas_gemm
(
generic_rocblas_gemm
(
as
,
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx,
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
generic_rocblas_batched_gemm
(
as
,
as
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
ca28e1e8
...
@@ -22,7 +22,8 @@ struct miopen_gemm
...
@@ -22,7 +22,8 @@ struct miopen_gemm
private:
private:
void
fill_result
(
const
shape
&
output_shape
,
const
argument
&
result
,
const
argument
&
c
)
const
;
void
fill_result
(
const
shape
&
output_shape
,
const
argument
&
result
,
const
argument
&
c
)
const
;
argument
batch_matmul
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
argument
batch_matmul
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
};
};
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment