Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d5fa82db
Commit
d5fa82db
authored
Dec 02, 2023
by
Umang Yadav
Browse files
add quant_dot support for fp8
parent
7e80f627
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
148 additions
and
56 deletions
+148
-56
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+7
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+5
-0
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+1
-1
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+39
-11
test/verify/batch_quant_dot_1.cpp
test/verify/batch_quant_dot_1.cpp
+13
-5
test/verify/batch_quant_dot_2.cpp
test/verify/batch_quant_dot_2.cpp
+7
-4
test/verify/batch_quant_dot_3.cpp
test/verify/batch_quant_dot_3.cpp
+6
-3
test/verify/batch_quant_dot_4.cpp
test/verify/batch_quant_dot_4.cpp
+6
-3
test/verify/batch_quant_dot_5.cpp
test/verify/batch_quant_dot_5.cpp
+6
-3
test/verify/quant_dot_3args_1.cpp
test/verify/quant_dot_3args_1.cpp
+13
-5
test/verify/quant_dot_3args_2.cpp
test/verify/quant_dot_3args_2.cpp
+12
-5
test/verify/quant_dot_3args_3.cpp
test/verify/quant_dot_3args_3.cpp
+11
-5
test/verify/quant_dot_3args_4.cpp
test/verify/quant_dot_3args_4.cpp
+12
-5
test/verify/quant_dot_3args_5.cpp
test/verify/quant_dot_3args_5.cpp
+10
-4
No files found.
src/include/migraphx/op/quant_dot.hpp
View file @
d5fa82db
...
@@ -44,9 +44,10 @@ struct quant_dot
...
@@ -44,9 +44,10 @@ struct quant_dot
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
if
(
t
!=
shape
::
int8_type
)
std
::
set
<
migraphx
::
shape
::
type_t
>
suppported_types
=
{
shape
::
int8_type
,
shape
::
fp8e4m3fnuz_type
};
if
(
not
contains
(
suppported_types
,
t
))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t
and fp8e4m3fnuz_type
"
);
}
}
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
...
@@ -73,6 +74,10 @@ struct quant_dot
...
@@ -73,6 +74,10 @@ struct quant_dot
auto
out_lens
=
a
.
lens
();
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
t
==
shape
::
fp8e4m3fnuz_type
)
{
return
{
shape
::
float_type
,
out_lens
};
}
// else int8 gemm
return
{
shape
::
int32_type
,
out_lens
};
return
{
shape
::
int32_type
,
out_lens
};
}
}
};
};
...
...
src/simplify_reshapes.cpp
View file @
d5fa82db
...
@@ -183,6 +183,11 @@ struct find_nested_convert
...
@@ -183,6 +183,11 @@ struct find_nested_convert
auto
x
=
ins
->
inputs
().
front
();
auto
x
=
ins
->
inputs
().
front
();
auto
input
=
x
->
inputs
().
front
();
auto
input
=
x
->
inputs
().
front
();
while
(
input
->
name
()
==
"convert"
)
{
input
=
input
->
inputs
().
front
();
}
if
(
ins
->
get_shape
()
!=
input
->
get_shape
())
if
(
ins
->
get_shape
()
!=
input
->
get_shape
())
return
;
return
;
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
d5fa82db
...
@@ -112,7 +112,7 @@ struct rocblas_gemm
...
@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
if
(
this
->
name
()
==
"gpu::gemm"
)
if
(
this
->
name
()
==
"gpu::gemm"
or
output_shape
.
type
()
==
migraphx
::
shape
::
float_type
)
{
{
gemm_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
gemm_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
}
...
...
src/targets/ref/lowering.cpp
View file @
d5fa82db
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include <migraphx/ref/lowering.hpp>
#include <migraphx/ref/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/convolution.hpp>
...
@@ -307,19 +308,46 @@ struct ref_quant_gemm
...
@@ -307,19 +308,46 @@ struct ref_quant_gemm
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
// first, convert the args[0] and args[1] from int8_t to int32_t
// first, convert the args[0] and args[1] from int8_t to int32_t
argument
arg_0
{{
shape
::
int32_type
,
{
args
.
at
(
0
).
get_shape
().
lens
()}}};
argument
arg_0
{{
output_shape
.
type
(),
{
args
.
at
(
0
).
get_shape
().
lens
()}}};
argument
arg_1
{{
shape
::
int32_type
,
{
args
.
at
(
1
).
get_shape
().
lens
()}}};
argument
arg_1
{{
output_shape
.
type
(),
{
args
.
at
(
1
).
get_shape
().
lens
()}}};
arg_0
.
visit
([
&
](
auto
output
)
{
if
(
output_shape
.
type
()
==
migraphx
::
shape
::
float_type
)
args
.
at
(
0
).
visit
(
{
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
arg_0
.
visit
([
&
](
auto
output
)
{
});
args
.
at
(
0
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
float
>
(
x
);
});
});
});
arg_1
.
visit
([
&
](
auto
output
)
{
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
(
args
.
at
(
1
).
visit
([
&
](
auto
input
)
{
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
});
return
static_cast
<
float
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
1.0
f
,
0.0
f
);
}
else
if
(
output_shape
.
type
()
==
migraphx
::
shape
::
int32_type
)
{
arg_0
.
visit
([
&
](
auto
output
)
{
args
.
at
(
0
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
int32_t
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
int32_t
{
1
},
int32_t
{
0
});
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
int32_t
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
int32_t
{
1
},
int32_t
{
0
});
}
return
result
;
return
result
;
}
}
...
...
test/verify/batch_quant_dot_1.cpp
View file @
d5fa82db
...
@@ -24,19 +24,23 @@
...
@@ -24,19 +24,23 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
7
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_instruction
(
auto
tl1
=
mm
->
add_instruction
(
...
@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
...
@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto
tl2
=
mm
->
add_instruction
(
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_1
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_2.cpp
View file @
d5fa82db
...
@@ -28,15 +28,16 @@
...
@@ -28,15 +28,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
>
template
<
migraphx
::
shape
::
type_t
DType
,
migraphx
::
shape
::
type_t
CType
>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_t
ype
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m3_shape
{
CT
ype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -45,3 +46,5 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
...
@@ -45,3 +46,5 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_2
<
migraphx
::
shape
::
int8_type
,
migraphx
::
shape
::
int32_type
>;
template
struct
batch_quant_dot_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
,
migraphx
::
shape
::
float_type
>;
test/verify/batch_quant_dot_3.cpp
View file @
d5fa82db
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
6
,
7
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
6
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
...
@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_4.cpp
View file @
d5fa82db
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
7
,
2
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
7
,
2
,
6
,
3
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
...
@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_5.cpp
View file @
d5fa82db
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
5
,
7
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
5
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
...
@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_dot_3args_1.cpp
View file @
d5fa82db
...
@@ -25,23 +25,31 @@
...
@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
1
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
1
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_1
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_2.cpp
View file @
d5fa82db
...
@@ -28,22 +28,29 @@
...
@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_2
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_3.cpp
View file @
d5fa82db
...
@@ -28,22 +28,28 @@
...
@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
2
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
2
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_3
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_3
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_4.cpp
View file @
d5fa82db
...
@@ -28,15 +28,18 @@
...
@@ -28,15 +28,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
...
@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_4
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_4
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_5.cpp
View file @
d5fa82db
...
@@ -28,14 +28,17 @@
...
@@ -28,14 +28,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
6
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
6
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
6
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
6
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
...
@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
}
);
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_5
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_5
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment