Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5d240fa4
Commit
5d240fa4
authored
Dec 05, 2023
by
Umang Yadav
Browse files
quant dot
parent
7772a428
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
8 deletions
+13
-8
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+1
-1
test/verify/batch_quant_dot_2.cpp
test/verify/batch_quant_dot_2.cpp
+12
-7
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
5d240fa4
...
@@ -195,7 +195,7 @@ struct gemm_impl
...
@@ -195,7 +195,7 @@ struct gemm_impl
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg
_type
;
output_type
=
get
_type
(
input_shapes
[
2
].
type
())
;
if
(
output_type
==
rocblas_datatype_i8_r
)
if
(
output_type
==
rocblas_datatype_i8_r
)
{
{
output_type
=
rocblas_datatype_i32_r
;
output_type
=
rocblas_datatype_i32_r
;
...
...
test/verify/batch_quant_dot_2.cpp
View file @
5d240fa4
...
@@ -25,26 +25,31 @@
...
@@ -25,26 +25,31 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
migraphx
::
shape
::
type_t
DType
,
migraphx
::
shape
::
type_t
CType
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
<
DType
,
CType
>>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
DType
,
{
3
,
2
,
2
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
migraphx
::
shape
m2_shape
{
DType
,
{
3
,
2
,
8
,
7
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m3_shape
{
CType
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_2
<
migraphx
::
shape
::
int8_type
,
migraphx
::
shape
::
int32_t
ype
>;
template
struct
batch_quant_dot_2
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_2
<
migraphx
::
shape
::
fp8e4m3fnuz
_type
,
migraphx
::
shape
::
float_type
>;
template
struct
batch_quant_dot_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment