Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
783e9474
Commit
783e9474
authored
Feb 08, 2022
by
Khalique Ahmed
Browse files
change type for alpha beta
parent
297bfdd0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
117 additions
and
57 deletions
+117
-57
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+117
-57
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
783e9474
...
@@ -83,8 +83,8 @@ void gemm_impl(context& ctx,
...
@@ -83,8 +83,8 @@ void gemm_impl(context& ctx,
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
compute_fp32
?
alpha
:
as
(
alpha
);
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
compute_fp32
?
beta
:
as
(
beta
);
auto
beta_r
=
as
(
beta
);
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
m
=
out_lens
[
dim_0
];
...
@@ -104,64 +104,124 @@ void gemm_impl(context& ctx,
...
@@ -104,64 +104,124 @@ void gemm_impl(context& ctx,
// column-major format. When doing a C = A * B, we actually do
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
// A and args[0] as B in calling the rocblas_gemm.
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
if
(
compute_fp32
)
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
rocblas_invoke
(
&
rocblas_gemm_ex
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
ctx
.
get_stream
().
get_rocblas
(),
n
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
k
,
n
,
&
alpha_r
,
m
,
to_pointer
(
args
.
at
(
1
)),
k
,
arg_type
,
&
alpha
,
ldb
,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
0
)),
arg_type
,
arg_type
,
ldb
,
lda
,
to_pointer
(
args
.
at
(
0
)),
&
beta_r
,
arg_type
,
to_pointer
(
args
[
2
]),
lda
,
output_type
,
&
beta
,
ldc
,
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
compute_type
,
output_type
,
rocblas_gemm_algo_standard
,
ldc
,
0
,
compute_type
,
flag
);
rocblas_gemm_algo_standard
,
0
,
flag
);
else
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
&
beta_r
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
}
}
else
else
{
{
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
if
(
compute_fp32
)
ctx
.
get_stream
().
get_rocblas
(),
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
ctx
.
get_stream
().
get_rocblas
(),
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
n
,
k
,
m
,
&
alpha_r
,
k
,
to_pointer
(
args
.
at
(
1
)),
&
alpha
,
arg_type
,
to_pointer
(
args
.
at
(
1
)),
ldb
,
arg_type
,
k
*
n
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
k
*
n
,
arg_type
,
to_pointer
(
args
.
at
(
0
)),
lda
,
arg_type
,
m
*
k
,
lda
,
&
beta_r
,
m
*
k
,
to_pointer
(
args
[
2
]),
&
beta
,
output_type
,
to_pointer
(
args
[
2
]),
ldc
,
output_type
,
m
*
n
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
m
*
n
,
output_type
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
ldc
,
output_type
,
m
*
n
,
ldc
,
num_matrices
,
m
*
n
,
compute_type
,
num_matrices
,
rocblas_gemm_algo_standard
,
compute_type
,
0
,
rocblas_gemm_algo_standard
,
flag
);
0
,
flag
);
else
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
k
*
n
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
m
*
k
,
&
beta_r
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
m
*
n
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
m
*
n
,
num_matrices
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
}
}
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment