Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ca28e1e8
Commit
ca28e1e8
authored
Mar 14, 2019
by
Shucai Xiao
Browse files
clang format
parent
02f359b2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
43 deletions
+45
-43
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+43
-42
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+2
-1
No files found.
src/targets/gpu/gemm.cpp
View file @
ca28e1e8
...
...
@@ -244,15 +244,15 @@ argument miopen_gemm::batch_matmul(context& ctx,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
out_lens
=
output_shape
.
lens
();
auto
an_dim
=
a_lens
.
size
();
auto
bn_dim
=
b_lens
.
size
();
auto
an_dim
=
a_lens
.
size
();
auto
bn_dim
=
b_lens
.
size
();
auto
outn_dim
=
out_lens
.
size
();
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
an_dim
-
1
:
an_dim
-
2
];
...
...
@@ -265,20 +265,24 @@ argument miopen_gemm::batch_matmul(context& ctx,
std
::
vector
<
std
::
size_t
>
a_batch_lens
(
a_lens
.
begin
(),
a_lens
.
begin
()
+
an_dim
-
2
);
std
::
vector
<
std
::
size_t
>
b_batch_lens
(
b_lens
.
begin
(),
b_lens
.
begin
()
+
bn_dim
-
2
);
if
(
a_batch_lens
==
b_batch_lens
||
a_batch_lens
.
empty
()
||
b_batch_lens
.
empty
())
if
(
a_batch_lens
==
b_batch_lens
||
a_batch_lens
.
empty
()
||
b_batch_lens
.
empty
())
{
std
::
size_t
numa_matrices
=
std
::
accumulate
(
a_batch_lens
.
begin
(),
a_batch_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
numb_matrices
=
std
::
accumulate
(
b_batch_lens
.
begin
(),
b_batch_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_matrices
=
std
::
max
(
numa_matrices
,
numb_matrices
);
rocblas_int
stride_a
=
(
numa_matrices
==
1
)
?
0
:
m
*
k
;
rocblas_int
stride_b
=
(
numb_matrices
==
1
)
?
0
:
k
*
n
;
rocblas_int
stride_c
=
m
*
n
;
std
::
size_t
numa_matrices
=
std
::
accumulate
(
a_batch_lens
.
begin
(),
a_batch_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
numb_matrices
=
std
::
accumulate
(
b_batch_lens
.
begin
(),
b_batch_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_matrices
=
std
::
max
(
numa_matrices
,
numb_matrices
);
rocblas_int
stride_a
=
(
numa_matrices
==
1
)
?
0
:
m
*
k
;
rocblas_int
stride_b
=
(
numb_matrices
==
1
)
?
0
:
k
*
n
;
rocblas_int
stride_c
=
m
*
n
;
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
...
...
@@ -313,19 +317,19 @@ argument miopen_gemm::batch_matmul(context& ctx,
shape_for_each
(
out_batch_shape
,
[
&
](
auto
out_idx
)
{
std
::
size_t
out_ind
=
out_batch_shape
.
index
(
out_idx
.
begin
(),
out_idx
.
end
());
auto
type_size
=
output_shape
.
type_size
();
auto
type_size
=
output_shape
.
type_size
();
std
::
vector
<
std
::
size_t
>
a_idx
(
a_batch_lens
.
size
());
std
::
vector
<
std
::
size_t
>
b_idx
(
b_batch_lens
.
size
());
std
::
transform
(
out_idx
.
begin
()
+
a_len_diff
,
out_idx
.
end
(),
a_batch_lens
.
begin
(),
a_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
out_idx
.
end
(),
a_batch_lens
.
begin
(),
a_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
std
::
transform
(
out_idx
.
begin
()
+
b_len_diff
,
out_idx
.
end
(),
b_batch_lens
.
begin
(),
b_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
out_idx
.
end
(),
b_batch_lens
.
begin
(),
b_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
std
::
size_t
a_ind
=
a_batch_shape
.
index
(
a_idx
.
begin
(),
a_idx
.
end
());
std
::
size_t
b_ind
=
b_batch_shape
.
index
(
b_idx
.
begin
(),
b_idx
.
end
());
...
...
@@ -336,22 +340,21 @@ argument miopen_gemm::batch_matmul(context& ctx,
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
generic_rocblas_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
[
1
],
k
*
n
*
b_ind
*
type_size
),
ldb
,
to_pointer
(
args
[
0
],
m
*
k
*
a_ind
*
type_size
),
lda
,
&
beta_r
,
to_pointer
(
args
[
2
],
m
*
n
*
out_ind
*
type_size
),
ldc
);
generic_rocblas_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
[
1
],
k
*
n
*
b_ind
*
type_size
),
ldb
,
to_pointer
(
args
[
0
],
m
*
k
*
a_ind
*
type_size
),
lda
,
&
beta_r
,
to_pointer
(
args
[
2
],
m
*
n
*
out_ind
*
type_size
),
ldc
);
});
});
}
...
...
@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx,
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
as
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
ca28e1e8
...
...
@@ -22,7 +22,8 @@ struct miopen_gemm
private:
void
fill_result
(
const
shape
&
output_shape
,
const
argument
&
result
,
const
argument
&
c
)
const
;
argument
batch_matmul
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
argument
batch_matmul
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
};
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment