Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
82f98478
Commit
82f98478
authored
Dec 05, 2023
by
Umang Yadav
Browse files
add comments
parent
cf91c2b1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
14 deletions
+26
-14
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+26
-14
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
82f98478
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
...
@@ -36,6 +37,20 @@ namespace migraphx {
...
@@ -36,6 +37,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
/*
The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3" BETA
API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that implicitly
casts the integer enum value to the required type so that it can be used inside the `common_args` generator.
*/
struct rb_compute_type
{
    // Integer value of whichever rocBLAS enum this object was constructed from.
    int type{0};

    // Intentionally non-explicit: lets a `rocblas_datatype` be assigned directly.
    rb_compute_type(rocblas_datatype dt) : type{static_cast<int>(dt)} {}

    // Intentionally non-explicit: lets a `rocblas_computetype` be assigned directly.
    rb_compute_type(rocblas_computetype ct) : type{static_cast<int>(ct)} {}

    // Casts the stored integer back to `rocblas_datatype`.
    operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }

    // Casts the stored integer back to `rocblas_computetype`.
    operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype
get_type
(
shape
::
type_t
type
)
rocblas_datatype
get_type
(
shape
::
type_t
type
)
{
{
...
@@ -185,12 +200,17 @@ struct gemm_impl
...
@@ -185,12 +200,17 @@ struct gemm_impl
{
{
output_type
=
rocblas_datatype_i32_r
;
output_type
=
rocblas_datatype_i32_r
;
}
}
compute_type
=
output_type
;
compute_type
=
rb_compute_type
{
output_type
}
;
if
(
compute_fp32
)
if
(
compute_fp32
)
{
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
compute_type
=
rocblas_datatype_f32_r
;
}
}
else
if
(
arg_type
==
rocblas_datatype_f8_r
)
{
assert
(
get_type
(
input_shapes
[
1
].
type
())
==
rocblas_datatype_f8_r
);
compute_type
=
rocblas_compute_type_f32
;
}
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
...
@@ -230,7 +250,6 @@ struct gemm_impl
...
@@ -230,7 +250,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
common_args
,
common_args
,
rocblas_compute_type_f32
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
solution_idx
,
solution_idx
,
gemm_flags
);
gemm_flags
);
...
@@ -240,7 +259,6 @@ struct gemm_impl
...
@@ -240,7 +259,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
common_args
,
rocblas_compute_type_f32
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
solution_idx
,
solution_idx
,
gemm_flags
);
gemm_flags
);
...
@@ -254,7 +272,6 @@ struct gemm_impl
...
@@ -254,7 +272,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
solution_idx
,
gemm_flags
);
gemm_flags
);
...
@@ -264,7 +281,6 @@ struct gemm_impl
...
@@ -264,7 +281,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
solution_idx
,
gemm_flags
);
gemm_flags
);
...
@@ -304,7 +320,6 @@ struct gemm_impl
...
@@ -304,7 +320,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
rocblas_gemm_flags_check_solution_index
);
...
@@ -314,7 +329,6 @@ struct gemm_impl
...
@@ -314,7 +329,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_ex
,
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
rocblas_gemm_flags_check_solution_index
);
...
@@ -365,7 +379,8 @@ struct gemm_impl
...
@@ -365,7 +379,8 @@ struct gemm_impl
output_type
,
output_type
,
ldd
,
ldd
,
d_stride
,
d_stride
,
num_matrices
);
num_matrices
,
compute_type
);
}
}
/**
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* Helper method to create that subset of a long rocBLAS argument list that is common
...
@@ -398,7 +413,8 @@ struct gemm_impl
...
@@ -398,7 +413,8 @@ struct gemm_impl
ldc
,
ldc
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
output_type
,
ldd
);
ldd
,
compute_type
);
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
...
@@ -428,7 +444,6 @@ struct gemm_impl
...
@@ -428,7 +444,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_args
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
gemm_flags
,
nullptr
,
nullptr
,
...
@@ -438,7 +453,6 @@ struct gemm_impl
...
@@ -438,7 +453,6 @@ struct gemm_impl
auto
common_sol_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
auto
common_sol_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_sol_args
,
common_sol_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
gemm_flags
,
solution_indices
.
data
(),
solution_indices
.
data
(),
...
@@ -449,7 +463,6 @@ struct gemm_impl
...
@@ -449,7 +463,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_args
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
gemm_flags
,
nullptr
,
nullptr
,
...
@@ -459,7 +472,6 @@ struct gemm_impl
...
@@ -459,7 +472,6 @@ struct gemm_impl
auto
common_sol_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
auto
common_sol_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_sol_args
,
common_sol_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
gemm_flags
,
solution_indices
.
data
(),
solution_indices
.
data
(),
...
@@ -521,7 +533,7 @@ struct gemm_impl
...
@@ -521,7 +533,7 @@ struct gemm_impl
rocblas_int
c_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
r
ocblas_data
type
compute_type
=
rocblas_datatype_f32_r
;
r
b_compute_
type
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
bool
is_3inputs
=
true
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment