Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3aa465fd
Commit
3aa465fd
authored
Nov 26, 2023
by
Umang Yadav
Browse files
compiles all right
parent
a6c57726
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
159 additions
and
24 deletions
+159
-24
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+9
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+98
-15
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
+2
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+32
-0
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+9
-0
test/verify/test_convert.cpp
test/verify/test_convert.cpp
+9
-9
No files found.
src/targets/gpu/CMakeLists.txt
View file @
3aa465fd
...
...
@@ -253,6 +253,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
# rocblas FP8 API
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex3"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_FP8_BETA_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
...
@@ -282,6 +284,13 @@ else()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
# Enable the rocBLAS FP8 GEMM path only when the installed rocBLAS exports
# rocblas_gemm_ex3 (detected above into HAS_ROCBLAS_FP8_BETA_API).
if(HAS_ROCBLAS_FP8_BETA_API)
    # MIGRAPHX_USE_ROCBLAS_FP8_API gates the FP8 code in the GPU target;
    # the two ROCBLAS_* definitions are passed through to the rocBLAS headers.
    target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
    # Was "STATUA": an unrecognised mode keyword is treated as message text,
    # so the line was not emitted as a STATUS message.
    message(STATUS "MIGraphX is using BETA API of rocBLAS for FP8 computations")
else()
    message(STATUS "rocBLAS does not have FP8 BETA API")
endif()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
...
...
src/targets/gpu/gemm_impl.cpp
View file @
3aa465fd
...
...
@@ -23,10 +23,12 @@
*/
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <type_traits>
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
...
...
@@ -46,7 +48,7 @@ rocblas_datatype get_type(shape::type_t type)
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
fp8e4m3fnuz_type
:
case
shape
::
fp8e4m3fnuz_type
:
return
rocblas_datatype_f8_r
;
case
shape
::
tuple_type
:
case
shape
::
bool_type
:
case
shape
::
uint16_type
:
...
...
@@ -217,23 +219,50 @@ struct gemm_impl
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
{
if
(
strided_batched
)
if
(
rocblas_fp8_available
()
and
std
::
any_of
(
input_args
.
begin
(),
input_args
.
end
(),
[](
const
auto
i
)
{
return
i
.
get_shape
().
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common_fp8
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common_fp8
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
}
}
...
...
@@ -331,6 +360,36 @@ struct gemm_impl
num_matrices
,
compute_type
);
}
/**
 * Builds the leading argument list for rocblas_gemm_strided_batched_ex3 (the FP8 call
 * made from run()). The trailing algo / solution index / flags are appended by the caller.
 *
 * The B operand (args[1]) is passed ahead of A (args[0]) and n ahead of m, mirroring the
 * non-FP8 helper; presumably this maps row-major data onto rocBLAS's layout -- confirm.
 * The compute type is fixed to rocblas_compute_type_f8_f8_f32.
 */
auto create_strided_batched_args_common_fp8(context& ctx, const std::vector<argument>& args) const
{
    // Translate a "transposed" flag into the matching rocBLAS operation enum.
    const auto as_rocblas_op = [](bool transposed) {
        return transposed ? rocblas_operation_transpose : rocblas_operation_none;
    };
    // With three inputs the result goes to args[3]; otherwise it aliases args[2].
    auto d_data = is_3inputs ? args[3].data() : args[2].data();

    return pack(ctx.get_stream().get_rocblas(),
                as_rocblas_op(transb),
                as_rocblas_op(transa),
                n,
                m,
                k,
                get_alpha(),
                // first operand: B
                args[1].data(),
                arg_type,
                ldb,
                b_stride,
                // second operand: A
                args[0].data(),
                arg_type,
                lda,
                a_stride,
                get_beta(),
                // C
                args[2].data(),
                output_type,
                ldc,
                c_stride,
                // D (output)
                d_data,
                output_type,
                ldd,
                d_stride,
                num_matrices,
                rocblas_compute_type_f8_f8_f32);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
...
...
@@ -366,6 +425,30 @@ struct gemm_impl
ldd
,
compute_type
);
}
/**
 * Builds the leading argument list for rocblas_gemm_ex3 (the non-batched FP8 call made
 * from run()). The trailing algo / solution index / flags are appended by the caller.
 *
 * The B operand (args[1]) is passed ahead of A (args[0]) and n ahead of m, mirroring the
 * non-FP8 helper; presumably this maps row-major data onto rocBLAS's layout -- confirm.
 * The compute type is fixed to rocblas_compute_type_f8_f8_f32.
 */
auto create_gemm_ex_args_common_fp8(context& ctx, const std::vector<argument>& args) const
{
    // Translate a "transposed" flag into the matching rocBLAS operation enum.
    const auto as_rocblas_op = [](bool transposed) {
        return transposed ? rocblas_operation_transpose : rocblas_operation_none;
    };
    // With three inputs the result goes to args[3]; otherwise it aliases args[2].
    auto d_data = is_3inputs ? args[3].data() : args[2].data();

    return pack(ctx.get_stream().get_rocblas(),
                as_rocblas_op(transb),
                as_rocblas_op(transa),
                n,
                m,
                k,
                get_alpha(),
                // first operand: B
                args[1].data(),
                arg_type,
                ldb,
                // second operand: A
                args[0].data(),
                arg_type,
                lda,
                get_beta(),
                // C
                args[2].data(),
                output_type,
                ldc,
                // D (output)
                d_data,
                output_type,
                ldd,
                rocblas_compute_type_f8_f8_f32);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
...
...
@@ -481,8 +564,8 @@ struct gemm_impl
rocblas_int
b_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
...
...
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
View file @
3aa465fd
...
...
@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT
bool
get_compute_fp32_flag
();
/// Reports whether the rocBLAS FP8 GEMM path may be used.
/// Always false when built without MIGRAPHX_USE_ROCBLAS_FP8_API; otherwise true only
/// when the current device name starts with "gfx9" and compares >= "gfx940".
MIGRAPHX_GPU_EXPORT
bool
rocblas_fp8_available
();
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/gpu/lowering.cpp
View file @
3aa465fd
...
...
@@ -220,12 +220,44 @@ struct miopen_apply
return
mod
->
insert_instruction
(
ins
,
make_op
(
"allocate"
,
{{
"shape"
,
to_value
(
s
)}}));
}
/**
 * Rewrites `ins` so that it is computed in fp32:
 *   1. every input of `ins` is wrapped in a convert-to-float instruction,
 *   2. the operator of `ins` is re-inserted on those converted inputs,
 *   3. the fp32 result is converted back to fp8e4m3fnuz, and that convert replaces `ins`.
 *
 * @param ins instruction whose inputs should be widened to float
 * @return the newly inserted fp32 instruction (not the trailing fp8 convert)
 */
instruction_ref convert_fp8_to_fp32(instruction_ref ins)
{
    // Both conversions differ only in their target type.
    const auto convert_to = [](migraphx::shape::type_t target) {
        return migraphx::make_op("convert", {{"target_type", migraphx::to_value(target)}});
    };

    // Work on a copy of the input list, as the original code does.
    const std::vector<instruction_ref> original_inputs = ins->inputs();
    std::vector<instruction_ref> widened_inputs;
    for(const auto& input : original_inputs)
    {
        auto widened = mod->insert_instruction(
            ins, convert_to(migraphx::shape::type_t::float_type), input);
        widened_inputs.push_back(widened);
    }

    auto fp32_result = mod->insert_instruction(ins, ins->get_operator(), {widened_inputs});
    auto narrowed    = mod->insert_instruction(
        ins, convert_to(migraphx::shape::type_t::fp8e4m3fnuz_type), fp32_result);
    mod->replace_instruction(ins, narrowed);
    return fp32_result;
}
template
<
typename
Op
>
void
add_gemm_op
(
const
std
::
string
&
name
)
{
apply_map
.
emplace
(
name
,
[
=
](
instruction_ref
ins
)
{
std
::
vector
<
instruction_ref
>
refs
=
ins
->
inputs
();
assert
(
refs
.
size
()
==
2
);
if
(
not
rocblas_fp8_available
()
and
std
::
any_of
(
refs
.
begin
(),
refs
.
end
(),
[](
const
auto
i
)
{
return
i
->
get_shape
().
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
{
// replace fp8 ins with fp32 ins
ins
=
convert_fp8_to_fp32
(
ins
);
}
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
refs
.
push_back
(
output
);
return
mod
->
replace_instruction
(
ins
,
rocblas_gemm
<
Op
>
{
Op
{},
1
,
0
,
compute_fp32
},
refs
);
...
...
src/targets/gpu/rocblas.cpp
View file @
3aa465fd
...
...
@@ -53,6 +53,15 @@ bool get_compute_fp32_flag()
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
}
// Reports whether the rocBLAS FP8 GEMM path may be used.
// Without MIGRAPHX_USE_ROCBLAS_FP8_API this is unconditionally false; with it, the
// part of the device name before the first ':' must start with "gfx9" and compare
// (as a string) greater than or equal to "gfx940".
bool rocblas_fp8_available()
{
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
    const auto device_name = trim(split_string(get_device_name(), ':').front());
    if(not starts_with(device_name, "gfx9"))
        return false;
    return device_name >= "gfx940";
#else
    return false;
#endif
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
test/verify/test_convert.cpp
View file @
3aa465fd
...
...
@@ -29,26 +29,26 @@
#include <migraphx/make_op.hpp>
struct
test_convert
:
verify_program
<
test_convert
>
template
<
migraphx
::
shape
::
type_t
From
,
migraphx
::
shape
::
type_t
To
>
struct
test_convert
:
verify_program
<
test_convert
<
From
,
To
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
int8_type
,
{
8
,
24
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
int8_type
,
{
24
,
6
}};
migraphx
::
shape
sa
{
From
,
{
8
,
24
}};
migraphx
::
shape
sb
{
From
,
{
24
,
6
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
pb
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
ia
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
pa
);
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
To
)}}),
pa
);
auto
ib
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
pb
);
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
To
)}}),
pb
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
ia
,
ib
);
return
p
;
};
};
template
struct
test_convert
<
migraphx
::
shape
::
int8_type
,
migraphx
::
shape
::
float_type
>;
template
struct
test_convert
<
migraphx
::
shape
::
fp8e4m3fnuz_type
,
migraphx
::
shape
::
float_type
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment