Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
037205c5
Commit
037205c5
authored
Nov 26, 2023
by
Umang Yadav
Browse files
Works now
parent
3aa465fd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
12 deletions
+21
-12
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+2
-2
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+4
-4
test/verify/gemm_2args_bmv.cpp
test/verify/gemm_2args_bmv.cpp
+8
-3
test/verify/gemm_2args_mm_1.cpp
test/verify/gemm_2args_mm_1.cpp
+7
-3
No files found.
src/targets/gpu/CMakeLists.txt
View file @
037205c5
...
@@ -286,9 +286,9 @@ endif()
...
@@ -286,9 +286,9 @@ endif()
if
(
HAS_ROCBLAS_FP8_BETA_API
)
if
(
HAS_ROCBLAS_FP8_BETA_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATU
A
"MIGraphX is using B
ETA
API of rocBLAS for FP8 computations"
)
message
(
STATU
S
"MIGraphX is using B
eta
API of rocBLAS for FP8 computations"
)
else
()
else
()
message
(
STATUS
"rocBLAS does not have F
P
8 B
ETA
API"
)
message
(
STATUS
"rocBLAS does not have F
p
8 B
eta
API"
)
endif
()
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
...
...
src/targets/gpu/gemm_impl.cpp
View file @
037205c5
...
@@ -229,7 +229,7 @@ struct gemm_impl
...
@@ -229,7 +229,7 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common_fp8
(
ctx
,
input_args
);
auto
common_args
=
create_strided_batched_args_common_fp8
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
common_args
,
common_args
,
rocblas_gemm_algo_s
olution_index
,
rocblas_gemm_algo_s
tandard
,
solution_idx
,
solution_idx
,
gemm_flags
);
gemm_flags
);
}
}
...
@@ -238,7 +238,7 @@ struct gemm_impl
...
@@ -238,7 +238,7 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common_fp8
(
ctx
,
input_args
);
auto
common_args
=
create_gemm_ex_args_common_fp8
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
common_args
,
rocblas_gemm_algo_s
olution_index
,
rocblas_gemm_algo_s
tandard
,
solution_idx
,
solution_idx
,
gemm_flags
);
gemm_flags
);
}
}
...
@@ -388,7 +388,7 @@ struct gemm_impl
...
@@ -388,7 +388,7 @@ struct gemm_impl
ldd
,
ldd
,
d_stride
,
d_stride
,
num_matrices
,
num_matrices
,
rocblas_compute_type_
f8_f8_
f32
);
rocblas_compute_type_f32
);
}
}
/**
/**
...
@@ -447,7 +447,7 @@ struct gemm_impl
...
@@ -447,7 +447,7 @@ struct gemm_impl
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
output_type
,
ldd
,
ldd
,
rocblas_compute_type_
f8_f8_
f32
);
rocblas_compute_type_f32
);
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
/**
...
...
test/verify/gemm_2args_bmv.cpp
View file @
037205c5
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
...
@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
...
@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_1.cpp
View file @
037205c5
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
1
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
auto
bl2
=
...
@@ -45,3 +46,6 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
...
@@ -45,3 +46,6 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment