Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
050184cb
Commit
050184cb
authored
Dec 03, 2023
by
Umang Yadav
Browse files
revert some changes
parent
3f213325
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
21 additions
and
57 deletions
+21
-57
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+2
-7
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+0
-5
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+1
-2
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+6
-3
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+1
-1
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+11
-39
No files found.
src/include/migraphx/op/quant_dot.hpp
View file @
050184cb
...
...
@@ -44,10 +44,9 @@ struct quant_dot
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
std
::
set
<
migraphx
::
shape
::
type_t
>
suppported_types
=
{
shape
::
int8_type
,
shape
::
fp8e4m3fnuz_type
};
if
(
not
contains
(
suppported_types
,
t
))
if
(
t
!=
shape
::
int8_type
)
{
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t
and fp8e4m3fnuz_type
"
);
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
if
(
not
std
::
all_of
(
...
...
@@ -74,10 +73,6 @@ struct quant_dot
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
t
==
shape
::
fp8e4m3fnuz_type
)
{
return
{
shape
::
float_type
,
out_lens
};
}
// else int8 gemm
return
{
shape
::
int32_type
,
out_lens
};
}
};
...
...
src/simplify_reshapes.cpp
View file @
050184cb
...
...
@@ -183,11 +183,6 @@ struct find_nested_convert
auto
x
=
ins
->
inputs
().
front
();
auto
input
=
x
->
inputs
().
front
();
while
(
input
->
name
()
==
"convert"
)
{
input
=
input
->
inputs
().
front
();
}
if
(
ins
->
get_shape
()
!=
input
->
get_shape
())
return
;
...
...
src/targets/gpu/fuse_ck.cpp
View file @
050184cb
...
...
@@ -69,8 +69,7 @@ struct ck_gemm
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
(
{
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
,
shape
::
fp8e4m3fnuz_type
},
t
);
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
...
...
src/targets/gpu/gemm_impl.cpp
View file @
050184cb
...
...
@@ -180,9 +180,12 @@ struct gemm_impl
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
get_type
(
input_shapes
[
2
].
type
());
compute_type
=
output_type
;
// not valid for ex3 BETA APIs. it has different type and set differently.
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
}
compute_type
=
output_type
;
if
(
compute_fp32
)
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
050184cb
...
...
@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
if
(
this
->
name
()
==
"gpu::gemm"
or
output_shape
.
type
()
==
migraphx
::
shape
::
float_type
)
if
(
this
->
name
()
==
"gpu::gemm"
)
{
gemm_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
...
...
src/targets/ref/lowering.cpp
View file @
050184cb
...
...
@@ -24,7 +24,6 @@
#include <migraphx/ref/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
...
...
@@ -308,46 +307,19 @@ struct ref_quant_gemm
{
argument
result
{
output_shape
};
// first, convert the args[0] and args[1] from int8_t to int32_t
argument
arg_0
{{
output_shape
.
type
(),
{
args
.
at
(
0
).
get_shape
().
lens
()}}};
argument
arg_1
{{
output_shape
.
type
(),
{
args
.
at
(
1
).
get_shape
().
lens
()}}};
if
(
output_shape
.
type
()
==
migraphx
::
shape
::
float_type
)
{
argument
arg_0
{{
shape
::
int32_type
,
{
args
.
at
(
0
).
get_shape
().
lens
()}}};
argument
arg_1
{{
shape
::
int32_type
,
{
args
.
at
(
1
).
get_shape
().
lens
()}}};
arg_0
.
visit
([
&
](
auto
output
)
{
args
.
at
(
0
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
float
>
(
x
);
});
});
args
.
at
(
0
).
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
float
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
1.0
f
,
0.0
f
);
}
else
if
(
output_shape
.
type
()
==
migraphx
::
shape
::
int32_type
)
{
arg_0
.
visit
([
&
](
auto
output
)
{
args
.
at
(
0
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
int32_t
>
(
x
);
});
});
args
.
at
(
1
).
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
int32_t
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
int32_t
{
1
},
int32_t
{
0
});
}
return
result
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment