Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ad3c4c1d
Commit
ad3c4c1d
authored
Feb 25, 2022
by
Khalique Ahmed
Browse files
use void pointer to select alpha beta
parent
4fffcdd5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
119 deletions
+68
-119
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+68
-119
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
ad3c4c1d
...
@@ -83,9 +83,22 @@ void gemm_impl(context& ctx,
...
@@ -83,9 +83,22 @@ void gemm_impl(context& ctx,
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
auto
beta_r
=
as
(
beta
);
// use void pointer to select different data type if using fp32 mode
void
*
alpha_v
{
&
alpha_r
};
void
*
beta_v
{
&
beta_r
};
if
(
compute_fp32
)
{
alpha_v
=
&
alpha
;
beta_v
=
&
beta
;
}
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
n
=
out_lens
[
dim_1
];
...
@@ -100,38 +113,6 @@ void gemm_impl(context& ctx,
...
@@ -100,38 +113,6 @@ void gemm_impl(context& ctx,
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
)
if
(
num_matrices
==
1
)
{
{
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
if
(
compute_fp32
)
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
&
beta
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
else
rocblas_invoke
(
&
rocblas_gemm_ex
,
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
@@ -139,14 +120,14 @@ void gemm_impl(context& ctx,
...
@@ -139,14 +120,14 @@ void gemm_impl(context& ctx,
n
,
n
,
m
,
m
,
k
,
k
,
&
alpha_
r
,
alpha_
v
,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
1
)),
arg_type
,
arg_type
,
ldb
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
to_pointer
(
args
.
at
(
0
)),
arg_type
,
arg_type
,
lda
,
lda
,
&
beta_
r
,
beta_
v
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
...
@@ -160,38 +141,6 @@ void gemm_impl(context& ctx,
...
@@ -160,38 +141,6 @@ void gemm_impl(context& ctx,
}
}
else
else
{
{
if
(
compute_fp32
)
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
k
*
n
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
m
*
k
,
&
beta
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
m
*
n
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
m
*
n
,
num_matrices
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
else
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
@@ -199,7 +148,7 @@ void gemm_impl(context& ctx,
...
@@ -199,7 +148,7 @@ void gemm_impl(context& ctx,
n
,
n
,
m
,
m
,
k
,
k
,
&
alpha_
r
,
alpha_
v
,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
1
)),
arg_type
,
arg_type
,
ldb
,
ldb
,
...
@@ -208,7 +157,7 @@ void gemm_impl(context& ctx,
...
@@ -208,7 +157,7 @@ void gemm_impl(context& ctx,
arg_type
,
arg_type
,
lda
,
lda
,
m
*
k
,
m
*
k
,
&
beta_
r
,
beta_
v
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment