Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6d0b6bcf
Unverified
Commit
6d0b6bcf
authored
Dec 05, 2023
by
Umang Yadav
Committed by
GitHub
Dec 05, 2023
Browse files
Add FP8 rocblas gemm support (#2473)
parent
e3e00547
Changes
48
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
272 additions
and
80 deletions
+272
-80
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+72
-20
src/include/migraphx/eliminate_data_type.hpp
src/include/migraphx/eliminate_data_type.hpp
+2
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+9
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+68
-17
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
+2
-0
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+2
-6
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+10
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+7
-0
test/verify/gemm_2args_bmv.cpp
test/verify/gemm_2args_bmv.cpp
+8
-3
test/verify/gemm_2args_mm_1.cpp
test/verify/gemm_2args_mm_1.cpp
+8
-3
test/verify/gemm_2args_mm_2.cpp
test/verify/gemm_2args_mm_2.cpp
+9
-3
test/verify/gemm_2args_mm_3.cpp
test/verify/gemm_2args_mm_3.cpp
+9
-3
test/verify/gemm_2args_mm_4.cpp
test/verify/gemm_2args_mm_4.cpp
+9
-3
test/verify/gemm_2args_mm_5.cpp
test/verify/gemm_2args_mm_5.cpp
+8
-3
test/verify/gemm_2args_mm_6.cpp
test/verify/gemm_2args_mm_6.cpp
+9
-3
test/verify/gemm_2args_mm_7.cpp
test/verify/gemm_2args_mm_7.cpp
+8
-3
test/verify/gemm_2args_mm_8.cpp
test/verify/gemm_2args_mm_8.cpp
+8
-3
test/verify/gemm_2args_mv.cpp
test/verify/gemm_2args_mv.cpp
+8
-3
test/verify/gemm_2args_vbm.cpp
test/verify/gemm_2args_vbm.cpp
+8
-3
test/verify/gemm_2args_vm.cpp
test/verify/gemm_2args_vm.cpp
+8
-3
No files found.
src/eliminate_data_type.cpp
View file @
6d0b6bcf
...
...
@@ -31,6 +31,72 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
 * Re-create `ins` so that it consumes only supported data types.
 *
 * Each input of `ins` whose shape type is in `unsupported_types` is routed through a
 * newly inserted "convert" to `target_type`. If no input needed converting, the module
 * is left untouched. Otherwise the operator (or its "general_data_type" variant, when
 * the operator advertises that attribute) is re-inserted before `ins` with the converted
 * inputs, and the users of `ins` are redirected:
 *  - non-tuple output: `ins` is replaced by a "convert" back to its original type;
 *  - tuple output: every user must be a "get_tuple_elem"; each one is re-created on top
 *    of the new instruction and converted back to its original type only when that type
 *    is in `unsupported_types`.
 *
 * @param m                 module that owns `ins`; new instructions are inserted before `ins`
 * @param ins               instruction to rewrite
 * @param target_type       type the unsupported inputs are converted to
 * @param unsupported_types set of types that must not reach the operator
 * @throws via MIGRAPHX_THROW when a tuple-producing `ins` has a user other than
 *         "get_tuple_elem"
 */
void insert_convert_to_supported_type(module& m,
                                      instruction_ref ins,
                                      migraphx::shape::type_t target_type,
                                      std::set<migraphx::shape::type_t> unsupported_types)
{
    const migraphx::shape::type_t orig_type = ins->get_shape().type();

    // Work on a copy of the input list so it can be compared with the original afterwards.
    std::vector<instruction_ref> inputs = ins->inputs();
    for(auto& arg : inputs)
    {
        if(contains(unsupported_types, arg->get_shape().type()))
        {
            arg = m.insert_instruction(
                ins,
                migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
                arg);
        }
    }

    // Nothing was converted: leave the instruction as it is.
    if(inputs == ins->inputs())
        return;

    auto op         = ins->get_operator();
    auto attributes = op.attributes();
    if(attributes.contains("general_data_type"))
    {
        // Swap in the operator named by the "general_data_type" attribute, keeping the
        // original operator's value.
        op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
    }
    auto new_ins = m.insert_instruction(ins, op, inputs);

    if(orig_type != shape::tuple_type)
    {
        // Plain output: convert the result back so downstream users see the original type.
        auto convert_back_ins = m.insert_instruction(
            ins,
            migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
            new_ins);
        m.replace_instruction(ins, convert_back_ins);
        return;
    }

    // Tuple output: iterate over a snapshot of the users, since they are replaced below.
    auto orig_outs = ins->outputs();
    const bool has_non_gte_user =
        std::any_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
            return out_ins->name() != "get_tuple_elem";
        });
    if(has_non_gte_user)
        MIGRAPHX_THROW("eliminate_data_type: Instruction with tuple output doesn't have all its "
                       "usages as get_tuple_elem instruction");

    for(const auto out_ins : orig_outs)
    {
        auto gte_ins       = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
        auto orig_out_type = out_ins->get_shape().type();
        if(contains(unsupported_types, orig_out_type))
        {
            // This element originally had an unsupported type: convert it back.
            auto gte_convert = m.insert_instruction(
                ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
            m.replace_instruction(out_ins, gte_convert);
        }
        else
        {
            m.replace_instruction(out_ins, gte_ins);
        }
    }
}
void
eliminate_data_type
::
apply
(
module
&
m
)
const
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
...
...
@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_none"
};
if
(
unsupported_types
.
empty
())
return
;
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()[
0
]
==
'@'
)
continue
;
if
(
contains
(
skip_op_names
,
ins
->
name
()))
continue
;
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
i
)
{
if
(
types
.
count
(
i
->
get_shape
().
type
())
==
0
)
return
i
;
return
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
target_type
}}),
i
);
});
if
(
inputs
==
ins
->
inputs
())
if
(
contains
(
skip_op_names
,
ins
->
name
())
and
not
contains
(
unsupported_ops
,
ins
->
name
()))
continue
;
auto
op
=
ins
->
get_operator
();
auto
attributes
=
op
.
attributes
();
if
(
attributes
.
contains
(
"general_data_type"
))
{
op
=
make_op
(
attributes
[
"general_data_type"
].
to
<
std
::
string
>
(),
op
.
to_value
());
}
auto
old_type
=
ins
->
get_shape
().
type
();
auto
out
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
convert
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
old_type
}}),
out
);
m
.
replace_instruction
(
ins
,
convert
);
if
(
contains
(
unsupported_ops
,
"all"
)
or
contains
(
unsupported_ops
,
ins
->
name
()))
insert_convert_to_supported_type
(
m
,
ins
,
target_type
,
unsupported_types
);
}
}
...
...
src/include/migraphx/eliminate_data_type.hpp
View file @
6d0b6bcf
...
...
@@ -40,8 +40,9 @@ struct module;
*/
struct
MIGRAPHX_EXPORT
eliminate_data_type
{
std
::
set
<
shape
::
type_t
>
types
;
std
::
set
<
shape
::
type_t
>
unsupported_
types
;
shape
::
type_t
target_type
;
std
::
set
<
std
::
string
>
unsupported_ops
=
{
"all"
};
std
::
string
name
()
const
{
return
"eliminate_data_type"
;
}
void
apply
(
module
&
m
)
const
;
};
...
...
src/targets/gpu/CMakeLists.txt
View file @
6d0b6bcf
...
...
@@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
# rocblas FP8 API
check_library_exists
(
roc::rocblas
"rocblas_gemm_strided_batched_ex3"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_FP8_BETA_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
...
@@ -288,6 +290,13 @@ else()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
if
(
HAS_ROCBLAS_FP8_BETA_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATUS
"MIGraphX is using Beta API of rocBLAS for FP8 computations"
)
else
()
message
(
STATUS
"rocBLAS does not have Fp8 Beta API"
)
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
...
...
src/targets/gpu/gemm_impl.cpp
View file @
6d0b6bcf
...
...
@@ -22,11 +22,14 @@
* THE SOFTWARE.
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <type_traits>
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
...
...
@@ -34,6 +37,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct
rb_compute_type
{
int
type
=
0
;
rb_compute_type
(
rocblas_datatype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
rb_compute_type
(
rocblas_computetype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
operator
rocblas_datatype
()
const
{
return
static_cast
<
rocblas_datatype
>
(
type
);
}
operator
rocblas_computetype
()
const
{
return
static_cast
<
rocblas_computetype
>
(
type
);
}
};
// Convert MIGraphX data types to the equivalent rocBLAS datatypes
rocblas_datatype
get_type
(
shape
::
type_t
type
)
{
...
...
@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
fp8e4m3fnuz_type
:
case
shape
::
fp8e4m3fnuz_type
:
return
rocblas_datatype_f8_r
;
case
shape
::
tuple_type
:
case
shape
::
bool_type
:
case
shape
::
uint16_type
:
...
...
@@ -183,12 +200,17 @@ struct gemm_impl
{
output_type
=
rocblas_datatype_i32_r
;
}
compute_type
=
output_type
;
compute_type
=
rb_compute_type
{
output_type
}
;
if
(
compute_fp32
)
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
}
if
(
arg_type
==
rocblas_datatype_f8_r
)
{
assert
(
get_type
(
input_shapes
[
1
].
type
())
==
rocblas_datatype_f8_r
);
compute_type
=
rocblas_compute_type_f32
;
}
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
...
...
@@ -217,23 +239,52 @@ struct gemm_impl
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
{
if
(
strided_batched
)
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
if
(
rocblas_fp8_available
()
and
std
::
any_of
(
input_args
.
begin
(),
input_args
.
end
(),
[](
const
auto
i
)
{
return
i
.
get_shape
().
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
}
}
else
#endif
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
}
}
...
...
@@ -331,7 +382,6 @@ struct gemm_impl
num_matrices
,
compute_type
);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
...
...
@@ -366,6 +416,7 @@ struct gemm_impl
ldd
,
compute_type
);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
...
...
@@ -481,8 +532,8 @@ struct gemm_impl
rocblas_int
b_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rb_compute_type
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
...
...
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
View file @
6d0b6bcf
...
...
@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT
bool
get_compute_fp32_flag
();
MIGRAPHX_GPU_EXPORT
bool
rocblas_fp8_available
();
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
6d0b6bcf
...
...
@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static
constexpr
__device__
fp8e5m2fnuz
min
()
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
...
...
@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
...
...
src/targets/gpu/rocblas.cpp
View file @
6d0b6bcf
...
...
@@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
}
/**
 * Report whether the rocBLAS FP8 path can be used.
 *
 * Always false when MIGRAPHX_USE_ROCBLAS_FP8_API is not defined at build time. Otherwise
 * the first ':'-separated field of get_device_name() is trimmed and must both start with
 * "gfx9" and compare greater than or equal to "gfx940".
 */
bool rocblas_fp8_available()
{
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
    const auto gfx_name = trim(split_string(get_device_name(), ':').front());
    if(not starts_with(gfx_name, "gfx9"))
        return false;
    return gfx_name >= "gfx940";
#else
    return false;
#endif
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/target.cpp
View file @
6d0b6bcf
...
...
@@ -105,6 +105,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
}
// clang-format off
return
{
...
...
@@ -136,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops
{},
dead_code_elimination
{},
auto_contiguous
{},
eliminate_data_type
{{
migraphx
::
shape
::
fp8e4m3fnuz_type
},
shape
::
float_type
,
unsupported_fp8_ops
},
dead_code_elimination
{},
optimize_module
{},
fuse_pointwise
{},
dead_code_elimination
{},
...
...
test/verify/gemm_2args_bmv.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
...
...
@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
return
p
;
}
};
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_1.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
3
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
1
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
...
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
return
p
;
}
};
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_2.cpp
View file @
6d0b6bcf
...
...
@@ -24,17 +24,19 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_2
:
verify_program
<
gemm_2args_mm_2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_2
:
verify_program
<
gemm_2args_mm_2
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
...
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
return
p
;
}
};
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_3.cpp
View file @
6d0b6bcf
...
...
@@ -24,17 +24,19 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_3
:
verify_program
<
gemm_2args_mm_3
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_3
:
verify_program
<
gemm_2args_mm_3
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
3
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
...
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
return
p
;
}
};
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_4.cpp
View file @
6d0b6bcf
...
...
@@ -23,18 +23,20 @@
*/
#include "verify_program.hpp"
#include <migraphx/shape.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_4
:
verify_program
<
gemm_2args_mm_4
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_4
:
verify_program
<
gemm_2args_mm_4
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
3
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
...
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
return
p
;
}
};
template
struct
gemm_2args_mm_4
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_4
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_4
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_5.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_5
:
verify_program
<
gemm_2args_mm_5
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_5
:
verify_program
<
gemm_2args_mm_5
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
2
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
...
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
return
p
;
}
};
template
struct
gemm_2args_mm_5
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_5
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_5
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_6.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_6
:
verify_program
<
gemm_2args_mm_6
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_6
:
verify_program
<
gemm_2args_mm_6
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
3
,
3
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
1
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
...
...
@@ -47,3 +49,7 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
return
p
;
}
};
template
struct
gemm_2args_mm_6
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_6
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_6
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_7.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_7
:
verify_program
<
gemm_2args_mm_7
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_7
:
verify_program
<
gemm_2args_mm_7
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
2
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
...
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
return
p
;
}
};
template
struct
gemm_2args_mm_7
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_7
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_7
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_8.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_8
:
verify_program
<
gemm_2args_mm_8
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_8
:
verify_program
<
gemm_2args_mm_8
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
128
,
32
},
{
4096
,
1
,
128
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_t
ype
,
{
32
,
32
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
128
,
32
},
{
4096
,
1
,
128
}};
migraphx
::
shape
b_shape
{
DT
ype
,
{
32
,
32
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
auto
bb
=
mm
->
add_instruction
(
...
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8>
return
p
;
}
};
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
float_type
>;
// template struct gemm_2args_mm_8<migraphx::shape::half_type>;
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mv.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mv
:
verify_program
<
gemm_2args_mv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mv
:
verify_program
<
gemm_2args_mv
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
...
...
@@ -44,3 +45,7 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
return
p
;
}
};
template
struct
gemm_2args_mv
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mv
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_vbm.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_vbm
:
verify_program
<
gemm_2args_vbm
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_vbm
:
verify_program
<
gemm_2args_vbm
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
5
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
2
,
2
,
5
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
ul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
l1
);
auto
bul1
=
mm
->
add_instruction
(
...
...
@@ -48,3 +49,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
return
p
;
}
};
template
struct
gemm_2args_vbm
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_vbm
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_vbm
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_vm.cpp
View file @
6d0b6bcf
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_vm
:
verify_program
<
gemm_2args_vm
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_vm
:
verify_program
<
gemm_2args_vm
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
,
4
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
ul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
...
...
@@ -45,3 +46,7 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm>
return
p
;
}
};
template
struct
gemm_2args_vm
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_vm
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_vm
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment