Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6d0b6bcf
"vscode:/vscode.git/clone" did not exist on "9b29da00731817be09408633d005edffd592e9ec"
Unverified
Commit
6d0b6bcf
authored
Dec 05, 2023
by
Umang Yadav
Committed by
GitHub
Dec 05, 2023
Browse files
Add FP8 rocblas gemm support (#2473)
parent
e3e00547
Changes
48
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
70 additions
and
29 deletions
+70
-29
test/verify/test_gemm_transposea_ex.cpp
test/verify/test_gemm_transposea_ex.cpp
+8
-3
test/verify/test_gemm_transposeab.cpp
test/verify/test_gemm_transposeab.cpp
+8
-3
test/verify/test_gemm_transposeb.cpp
test/verify/test_gemm_transposeb.cpp
+8
-3
test/verify/test_gemm_transposeb_ex.cpp
test/verify/test_gemm_transposeb_ex.cpp
+8
-3
test/verify/test_mul_dot_a.cpp
test/verify/test_mul_dot_a.cpp
+9
-5
test/verify/test_mul_dot_b.cpp
test/verify/test_mul_dot_b.cpp
+10
-5
test/verify/test_unbatched_gemm_1.cpp
test/verify/test_unbatched_gemm_1.cpp
+10
-4
test/verify/test_unbatched_gemm_2.cpp
test/verify/test_unbatched_gemm_2.cpp
+9
-3
No files found.
test/verify/test_gemm_transposea_ex.cpp
View file @
6d0b6bcf
...
@@ -27,17 +27,22 @@
...
@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_gemm_transposea_ex
:
verify_program
<
test_gemm_transposea_ex
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_gemm_transposea_ex
:
verify_program
<
test_gemm_transposea_ex
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
1
,
5
,
4
}});
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
DT
ype
,
{
1
,
1
,
5
,
4
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
1
,
5
,
3
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
DT
ype
,
{
1
,
1
,
5
,
3
}});
auto
at
=
auto
at
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
a
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
a
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
at
,
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
at
,
b
);
return
p
;
return
p
;
}
}
};
};
template
struct
test_gemm_transposea_ex
<
migraphx
::
shape
::
float_type
>;
template
struct
test_gemm_transposea_ex
<
migraphx
::
shape
::
half_type
>;
template
struct
test_gemm_transposea_ex
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_gemm_transposeab.cpp
View file @
6d0b6bcf
...
@@ -27,17 +27,22 @@
...
@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_gemm_transposeab
:
verify_program
<
test_gemm_transposeab
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_gemm_transposeab
:
verify_program
<
test_gemm_transposeab
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
,
4
}});
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
DT
ype
,
{
5
,
4
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
5
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
DT
ype
,
{
3
,
5
}});
auto
at
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
a
);
auto
at
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
a
);
auto
bt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
b
);
auto
bt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
at
,
bt
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
at
,
bt
);
return
p
;
return
p
;
}
}
};
};
template
struct
test_gemm_transposeab
<
migraphx
::
shape
::
float_type
>;
template
struct
test_gemm_transposeab
<
migraphx
::
shape
::
half_type
>;
template
struct
test_gemm_transposeab
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_gemm_transposeb.cpp
View file @
6d0b6bcf
...
@@ -27,16 +27,21 @@
...
@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_gemm_transposeb
:
verify_program
<
test_gemm_transposeb
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_gemm_transposeb
:
verify_program
<
test_gemm_transposeb
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
4
,
5
}});
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
DT
ype
,
{
4
,
5
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
5
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
DT
ype
,
{
3
,
5
}});
auto
bt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
b
);
auto
bt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
bt
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
bt
);
return
p
;
return
p
;
}
}
};
};
template
struct
test_gemm_transposeb
<
migraphx
::
shape
::
float_type
>;
template
struct
test_gemm_transposeb
<
migraphx
::
shape
::
half_type
>;
template
struct
test_gemm_transposeb
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_gemm_transposeb_ex.cpp
View file @
6d0b6bcf
...
@@ -27,17 +27,22 @@
...
@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_gemm_transposeb_ex
:
verify_program
<
test_gemm_transposeb_ex
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_gemm_transposeb_ex
:
verify_program
<
test_gemm_transposeb_ex
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
4
,
5
}});
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
DT
ype
,
{
1
,
4
,
5
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
3
,
5
}});
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
DT
ype
,
{
1
,
3
,
5
}});
auto
bt
=
auto
bt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
bt
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
bt
);
return
p
;
return
p
;
}
}
};
};
template
struct
test_gemm_transposeb_ex
<
migraphx
::
shape
::
float_type
>;
template
struct
test_gemm_transposeb_ex
<
migraphx
::
shape
::
half_type
>;
template
struct
test_gemm_transposeb_ex
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_mul_dot_a.cpp
View file @
6d0b6bcf
...
@@ -27,17 +27,17 @@
...
@@ -27,17 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_mul_dot_a
:
verify_program
<
test_mul_dot_a
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_mul_dot_a
:
verify_program
<
test_mul_dot_a
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
as
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
256
,
32
}};
migraphx
::
shape
as
{
DT
ype
,
{
2
,
256
,
32
}};
migraphx
::
shape
bs
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
32
,
128
}};
migraphx
::
shape
bs
{
DT
ype
,
{
2
,
32
,
128
}};
auto
a
=
mm
->
add_parameter
(
"input"
,
as
);
auto
a
=
mm
->
add_parameter
(
"input"
,
as
);
auto
lit
=
auto
lit
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
DType
,
{
1
,
1
,
32
}}));
mm
->
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
1
,
1
,
32
}}));
auto
litb
=
mm
->
add_instruction
(
auto
litb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
as
.
lens
()}}),
lit
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
as
.
lens
()}}),
lit
);
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
a
,
litb
);
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
a
,
litb
);
...
@@ -47,3 +47,7 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a>
...
@@ -47,3 +47,7 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a>
return
p
;
return
p
;
}
}
};
};
template
struct
test_mul_dot_a
<
migraphx
::
shape
::
float_type
>;
template
struct
test_mul_dot_a
<
migraphx
::
shape
::
half_type
>;
template
struct
test_mul_dot_a
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_mul_dot_b.cpp
View file @
6d0b6bcf
...
@@ -27,17 +27,18 @@
...
@@ -27,17 +27,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_mul_dot_b
:
verify_program
<
test_mul_dot_b
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_mul_dot_b
:
verify_program
<
test_mul_dot_b
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
as
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
256
,
32
}};
migraphx
::
shape
as
{
DT
ype
,
{
2
,
256
,
32
}};
migraphx
::
shape
bs
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
32
,
128
}};
migraphx
::
shape
bs
{
DT
ype
,
{
2
,
32
,
128
}};
auto
b
=
mm
->
add_parameter
(
"input"
,
bs
);
auto
b
=
mm
->
add_parameter
(
"input"
,
bs
);
auto
lit
=
auto
lit
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
DType
,
{
1
,
32
,
1
}}));
mm
->
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
1
,
32
,
1
}}));
auto
litb
=
mm
->
add_instruction
(
auto
litb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
bs
.
lens
()}}),
lit
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
bs
.
lens
()}}),
lit
);
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
b
,
litb
);
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
b
,
litb
);
...
@@ -47,3 +48,7 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b>
...
@@ -47,3 +48,7 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b>
return
p
;
return
p
;
}
}
};
};
template
struct
test_mul_dot_b
<
migraphx
::
shape
::
float_type
>;
template
struct
test_mul_dot_b
<
migraphx
::
shape
::
half_type
>;
template
struct
test_mul_dot_b
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_unbatched_gemm_1.cpp
View file @
6d0b6bcf
...
@@ -27,15 +27,17 @@
...
@@ -27,15 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
test_unbatched_gemm_1
:
verify_program
<
test_unbatched_gemm_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_unbatched_gemm_1
:
verify_program
<
test_unbatched_gemm_1
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
32
,
64
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
32
,
64
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
64
,
64
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
64
,
64
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
32
,
192
}};
migraphx
::
shape
m3_shape
{
DT
ype
,
{
2
,
32
,
192
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
m2_shape
));
auto
l2
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
m2_shape
));
l2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
64
,
64
}}}),
l2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
64
,
64
}}}),
...
@@ -56,3 +58,7 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
...
@@ -56,3 +58,7 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
return
p
;
return
p
;
}
}
};
};
template
struct
test_unbatched_gemm_1
<
migraphx
::
shape
::
float_type
>;
template
struct
test_unbatched_gemm_1
<
migraphx
::
shape
::
half_type
>;
template
struct
test_unbatched_gemm_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_unbatched_gemm_2.cpp
View file @
6d0b6bcf
...
@@ -27,14 +27,16 @@
...
@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
test_unbatched_gemm_2
:
verify_program
<
test_unbatched_gemm_2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_unbatched_gemm_2
:
verify_program
<
test_unbatched_gemm_2
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
4
,
32
,
64
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
4
,
32
,
64
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
64
,
64
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
64
,
64
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
m2_shape
));
auto
l2
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
m2_shape
));
l2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
64
,
64
}}}),
l2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
64
,
64
}}}),
...
@@ -44,3 +46,7 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
...
@@ -44,3 +46,7 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
return
p
;
return
p
;
}
}
};
};
template
struct
test_unbatched_gemm_2
<
migraphx
::
shape
::
float_type
>;
template
struct
test_unbatched_gemm_2
<
migraphx
::
shape
::
half_type
>;
template
struct
test_unbatched_gemm_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment