Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6d0b6bcf
"src/include/vscode:/vscode.git/clone" did not exist on "99ee76c0e3aaad335b56b5c3a404feffd250f9df"
Unverified
Commit
6d0b6bcf
authored
Dec 05, 2023
by
Umang Yadav
Committed by
GitHub
Dec 05, 2023
Browse files
Add FP8 rocblas gemm support (#2473)
parent
e3e00547
Changes
48
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
272 additions
and
80 deletions
+272
-80
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+72
-20
src/include/migraphx/eliminate_data_type.hpp
src/include/migraphx/eliminate_data_type.hpp
+2
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+9
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+68
-17
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
+2
-0
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+2
-6
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+10
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+7
-0
test/verify/gemm_2args_bmv.cpp
test/verify/gemm_2args_bmv.cpp
+8
-3
test/verify/gemm_2args_mm_1.cpp
test/verify/gemm_2args_mm_1.cpp
+8
-3
test/verify/gemm_2args_mm_2.cpp
test/verify/gemm_2args_mm_2.cpp
+9
-3
test/verify/gemm_2args_mm_3.cpp
test/verify/gemm_2args_mm_3.cpp
+9
-3
test/verify/gemm_2args_mm_4.cpp
test/verify/gemm_2args_mm_4.cpp
+9
-3
test/verify/gemm_2args_mm_5.cpp
test/verify/gemm_2args_mm_5.cpp
+8
-3
test/verify/gemm_2args_mm_6.cpp
test/verify/gemm_2args_mm_6.cpp
+9
-3
test/verify/gemm_2args_mm_7.cpp
test/verify/gemm_2args_mm_7.cpp
+8
-3
test/verify/gemm_2args_mm_8.cpp
test/verify/gemm_2args_mm_8.cpp
+8
-3
test/verify/gemm_2args_mv.cpp
test/verify/gemm_2args_mv.cpp
+8
-3
test/verify/gemm_2args_vbm.cpp
test/verify/gemm_2args_vbm.cpp
+8
-3
test/verify/gemm_2args_vm.cpp
test/verify/gemm_2args_vm.cpp
+8
-3
No files found.
src/eliminate_data_type.cpp
View file @
6d0b6bcf
...
@@ -31,6 +31,72 @@
...
@@ -31,6 +31,72 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
insert_convert_to_supported_type
(
module
&
m
,
instruction_ref
ins
,
migraphx
::
shape
::
type_t
target_type
,
std
::
set
<
migraphx
::
shape
::
type_t
>
unsupported_types
)
{
migraphx
::
shape
::
type_t
orig_type
=
ins
->
get_shape
().
type
();
std
::
vector
<
instruction_ref
>
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
const
auto
&
i
)
{
if
(
contains
(
unsupported_types
,
i
->
get_shape
().
type
()))
{
return
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
target_type
)}}),
i
);
}
else
{
return
i
;
}
});
// if no change
if
(
inputs
==
ins
->
inputs
())
return
;
auto
op
=
ins
->
get_operator
();
auto
attributes
=
op
.
attributes
();
if
(
attributes
.
contains
(
"general_data_type"
))
{
op
=
make_op
(
attributes
[
"general_data_type"
].
to
<
std
::
string
>
(),
op
.
to_value
());
}
auto
new_ins
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
if
(
orig_type
==
shape
::
tuple_type
)
{
auto
orig_outs
=
ins
->
outputs
();
if
(
not
std
::
all_of
(
orig_outs
.
begin
(),
orig_outs
.
end
(),
[
&
](
const
auto
out_ins
)
{
return
out_ins
->
name
()
==
"get_tuple_elem"
;
}))
MIGRAPHX_THROW
(
"eliminate_data_type: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction"
);
std
::
transform
(
orig_outs
.
begin
(),
orig_outs
.
end
(),
orig_outs
.
begin
(),
[
&
](
const
auto
out_ins
)
{
auto
gte_ins
=
m
.
insert_instruction
(
ins
,
out_ins
->
get_operator
(),
new_ins
);
auto
orig_out_type
=
out_ins
->
get_shape
().
type
();
if
(
contains
(
unsupported_types
,
orig_out_type
))
{
auto
gte_convert
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
orig_out_type
}}),
gte_ins
);
return
m
.
replace_instruction
(
out_ins
,
gte_convert
);
}
else
{
return
m
.
replace_instruction
(
out_ins
,
gte_ins
);
}
});
}
else
{
auto
convert_back_ins
=
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
orig_type
)}}),
new_ins
);
m
.
replace_instruction
(
ins
,
convert_back_ins
);
}
}
void
eliminate_data_type
::
apply
(
module
&
m
)
const
void
eliminate_data_type
::
apply
(
module
&
m
)
const
{
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
...
@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
...
@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_mul"
,
"scatternd_none"
};
"scatternd_none"
};
if
(
unsupported_types
.
empty
())
return
;
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()[
0
]
==
'@'
)
if
(
ins
->
name
()[
0
]
==
'@'
)
continue
;
continue
;
if
(
contains
(
skip_op_names
,
ins
->
name
()))
if
(
contains
(
skip_op_names
,
ins
->
name
())
and
not
contains
(
unsupported_ops
,
ins
->
name
()))
continue
;
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
i
)
{
if
(
types
.
count
(
i
->
get_shape
().
type
())
==
0
)
return
i
;
return
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
target_type
}}),
i
);
});
if
(
inputs
==
ins
->
inputs
())
continue
;
continue
;
auto
op
=
ins
->
get_operator
();
if
(
contains
(
unsupported_ops
,
"all"
)
or
contains
(
unsupported_ops
,
ins
->
name
()))
auto
attributes
=
op
.
attributes
();
insert_convert_to_supported_type
(
m
,
ins
,
target_type
,
unsupported_types
);
if
(
attributes
.
contains
(
"general_data_type"
))
{
op
=
make_op
(
attributes
[
"general_data_type"
].
to
<
std
::
string
>
(),
op
.
to_value
());
}
auto
old_type
=
ins
->
get_shape
().
type
();
auto
out
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
convert
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
old_type
}}),
out
);
m
.
replace_instruction
(
ins
,
convert
);
}
}
}
}
...
...
src/include/migraphx/eliminate_data_type.hpp
View file @
6d0b6bcf
...
@@ -40,8 +40,9 @@ struct module;
...
@@ -40,8 +40,9 @@ struct module;
*/
*/
struct
MIGRAPHX_EXPORT
eliminate_data_type
struct
MIGRAPHX_EXPORT
eliminate_data_type
{
{
std
::
set
<
shape
::
type_t
>
types
;
std
::
set
<
shape
::
type_t
>
unsupported_
types
;
shape
::
type_t
target_type
;
shape
::
type_t
target_type
;
std
::
set
<
std
::
string
>
unsupported_ops
=
{
"all"
};
std
::
string
name
()
const
{
return
"eliminate_data_type"
;
}
std
::
string
name
()
const
{
return
"eliminate_data_type"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
...
...
src/targets/gpu/CMakeLists.txt
View file @
6d0b6bcf
...
@@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
...
@@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
# rocblas FP8 API
check_library_exists
(
roc::rocblas
"rocblas_gemm_strided_batched_ex3"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_FP8_BETA_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
@@ -288,6 +290,13 @@ else()
...
@@ -288,6 +290,13 @@ else()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
endif
()
if
(
HAS_ROCBLAS_FP8_BETA_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATUS
"MIGraphX is using Beta API of rocBLAS for FP8 computations"
)
else
()
message
(
STATUS
"rocBLAS does not have Fp8 Beta API"
)
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
...
...
src/targets/gpu/gemm_impl.cpp
View file @
6d0b6bcf
...
@@ -22,11 +22,14 @@
...
@@ -22,11 +22,14 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <migraphx/time.hpp>
#include <type_traits>
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
...
@@ -34,6 +37,20 @@ namespace migraphx {
...
@@ -34,6 +37,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct
rb_compute_type
{
int
type
=
0
;
rb_compute_type
(
rocblas_datatype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
rb_compute_type
(
rocblas_computetype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
operator
rocblas_datatype
()
const
{
return
static_cast
<
rocblas_datatype
>
(
type
);
}
operator
rocblas_computetype
()
const
{
return
static_cast
<
rocblas_computetype
>
(
type
);
}
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype
get_type
(
shape
::
type_t
type
)
rocblas_datatype
get_type
(
shape
::
type_t
type
)
{
{
...
@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
...
@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
fp8e4m3fnuz_type
:
case
shape
::
fp8e4m3fnuz_type
:
return
rocblas_datatype_f8_r
;
case
shape
::
tuple_type
:
case
shape
::
tuple_type
:
case
shape
::
bool_type
:
case
shape
::
bool_type
:
case
shape
::
uint16_type
:
case
shape
::
uint16_type
:
...
@@ -183,12 +200,17 @@ struct gemm_impl
...
@@ -183,12 +200,17 @@ struct gemm_impl
{
{
output_type
=
rocblas_datatype_i32_r
;
output_type
=
rocblas_datatype_i32_r
;
}
}
compute_type
=
output_type
;
compute_type
=
rb_compute_type
{
output_type
}
;
if
(
compute_fp32
)
if
(
compute_fp32
)
{
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
compute_type
=
rocblas_datatype_f32_r
;
}
}
if
(
arg_type
==
rocblas_datatype_f8_r
)
{
assert
(
get_type
(
input_shapes
[
1
].
type
())
==
rocblas_datatype_f8_r
);
compute_type
=
rocblas_compute_type_f32
;
}
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
...
@@ -216,6 +238,34 @@ struct gemm_impl
...
@@ -216,6 +238,34 @@ struct gemm_impl
}
}
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
{
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
if
(
rocblas_fp8_available
()
and
std
::
any_of
(
input_args
.
begin
(),
input_args
.
end
(),
[](
const
auto
i
)
{
return
i
.
get_shape
().
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
{
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
}
}
else
#endif
{
{
if
(
strided_batched
)
if
(
strided_batched
)
{
{
...
@@ -236,6 +286,7 @@ struct gemm_impl
...
@@ -236,6 +286,7 @@ struct gemm_impl
gemm_flags
);
gemm_flags
);
}
}
}
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
auto
validate
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
solution_idx
)
const
auto
validate
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
solution_idx
)
const
...
@@ -331,7 +382,6 @@ struct gemm_impl
...
@@ -331,7 +382,6 @@ struct gemm_impl
num_matrices
,
num_matrices
,
compute_type
);
compute_type
);
}
}
/**
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
* to multiple "gemm_ex..." calls.
...
@@ -366,6 +416,7 @@ struct gemm_impl
...
@@ -366,6 +416,7 @@ struct gemm_impl
ldd
,
ldd
,
compute_type
);
compute_type
);
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
...
@@ -481,8 +532,8 @@ struct gemm_impl
...
@@ -481,8 +532,8 @@ struct gemm_impl
rocblas_int
b_stride
=
0
;
rocblas_int
b_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rb_compute_type
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
bool
is_3inputs
=
true
;
...
...
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
View file @
6d0b6bcf
...
@@ -40,6 +40,8 @@ struct context;
...
@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT
bool
get_compute_fp32_flag
();
MIGRAPHX_GPU_EXPORT
bool
get_compute_fp32_flag
();
MIGRAPHX_GPU_EXPORT
bool
rocblas_fp8_available
();
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
6d0b6bcf
...
@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
...
@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static
constexpr
__device__
fp8e5m2fnuz
min
()
static
constexpr
__device__
fp8e5m2fnuz
min
()
{
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
...
@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
...
@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
}
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
...
...
src/targets/gpu/rocblas.cpp
View file @
6d0b6bcf
...
@@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
...
@@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
}
}
bool
rocblas_fp8_available
()
{
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return
false
;
#else
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
#endif
}
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/targets/gpu/target.cpp
View file @
6d0b6bcf
...
@@ -105,6 +105,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -105,6 +105,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
}
// clang-format off
// clang-format off
return
return
{
{
...
@@ -136,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -136,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops
{},
prefuse_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
eliminate_data_type
{{
migraphx
::
shape
::
fp8e4m3fnuz_type
},
shape
::
float_type
,
unsupported_fp8_ops
},
dead_code_elimination
{},
optimize_module
{},
optimize_module
{},
fuse_pointwise
{},
fuse_pointwise
{},
dead_code_elimination
{},
dead_code_elimination
{},
...
...
test/verify/gemm_2args_bmv.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
...
@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
...
@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_1.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
1
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
auto
bl2
=
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_2.cpp
View file @
6d0b6bcf
...
@@ -24,17 +24,19 @@
...
@@ -24,17 +24,19 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_2
:
verify_program
<
gemm_2args_mm_2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_2
:
verify_program
<
gemm_2args_mm_2
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
auto
bl2
=
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_3.cpp
View file @
6d0b6bcf
...
@@ -24,17 +24,19 @@
...
@@ -24,17 +24,19 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_3
:
verify_program
<
gemm_2args_mm_3
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_3
:
verify_program
<
gemm_2args_mm_3
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_4.cpp
View file @
6d0b6bcf
...
@@ -23,18 +23,20 @@
...
@@ -23,18 +23,20 @@
*/
*/
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/shape.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_4
:
verify_program
<
gemm_2args_mm_4
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_4
:
verify_program
<
gemm_2args_mm_4
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_4
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_4
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_4
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_5.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_5
:
verify_program
<
gemm_2args_mm_5
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_5
:
verify_program
<
gemm_2args_mm_5
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
2
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_5
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_5
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_5
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_6.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,16 @@
...
@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_6
:
verify_program
<
gemm_2args_mm_6
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_6
:
verify_program
<
gemm_2args_mm_6
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
3
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
1
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
...
@@ -47,3 +49,7 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
...
@@ -47,3 +49,7 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_6
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_6
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_6
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_7.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_7
:
verify_program
<
gemm_2args_mm_7
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_7
:
verify_program
<
gemm_2args_mm_7
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
2
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_7
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_7
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_7
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_8.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_8
:
verify_program
<
gemm_2args_mm_8
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_8
:
verify_program
<
gemm_2args_mm_8
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
128
,
32
},
{
4096
,
1
,
128
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
128
,
32
},
{
4096
,
1
,
128
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_t
ype
,
{
32
,
32
}};
migraphx
::
shape
b_shape
{
DT
ype
,
{
32
,
32
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
a
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
auto
bb
=
mm
->
add_instruction
(
auto
bb
=
mm
->
add_instruction
(
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8>
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
float_type
>;
// template struct gemm_2args_mm_8<migraphx::shape::half_type>;
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mv.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mv
:
verify_program
<
gemm_2args_mv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mv
:
verify_program
<
gemm_2args_mv
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
...
@@ -44,3 +45,7 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
...
@@ -44,3 +45,7 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mv
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mv
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_vbm.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_vbm
:
verify_program
<
gemm_2args_vbm
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_vbm
:
verify_program
<
gemm_2args_vbm
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
5
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
2
,
2
,
5
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
ul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
l1
);
auto
ul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
l1
);
auto
bul1
=
mm
->
add_instruction
(
auto
bul1
=
mm
->
add_instruction
(
...
@@ -48,3 +49,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
...
@@ -48,3 +49,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_vbm
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_vbm
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_vbm
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_vm.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_vm
:
verify_program
<
gemm_2args_vm
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_vm
:
verify_program
<
gemm_2args_vm
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
ul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
l1
);
auto
ul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
...
@@ -45,3 +46,7 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm>
...
@@ -45,3 +46,7 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_vm
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_vm
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_vm
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment