Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c52d1f69
Commit
c52d1f69
authored
Dec 06, 2023
by
Umang Yadav
Browse files
revert changes for nested converts
parent
af2ffd63
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
44 deletions
+54
-44
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+0
-16
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+38
-12
src/targets/ref/target.cpp
src/targets/ref/target.cpp
+11
-15
test/verify/main.cpp
test/verify/main.cpp
+5
-1
No files found.
src/eliminate_data_type.cpp
View file @
c52d1f69
...
...
@@ -120,22 +120,6 @@ void eliminate_data_type::apply(module& m) const
if
(
contains
(
unsupported_ops
,
"all"
)
or
contains
(
unsupported_ops
,
ins
->
name
()))
insert_convert_to_supported_type
(
m
,
ins
,
target_type
,
unsupported_types
);
}
// remove nested converts
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
==
"convert"
)
{
auto
convert_input
=
ins
->
inputs
().
front
();
while
(
convert_input
->
name
()
==
"convert"
)
{
convert_input
=
convert_input
->
inputs
().
front
();
}
if
(
convert_input
->
get_shape
()
==
ins
->
get_shape
())
{
m
.
replace_instruction
(
ins
,
convert_input
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/ref/lowering.cpp
View file @
c52d1f69
...
...
@@ -307,20 +307,46 @@ struct ref_quant_gemm
{
argument
result
{
output_shape
};
// first, convert the args[0] and args[1] from int8_t to int32_t
argument
arg_0
{{
shape
::
int32_
type
,
{
args
.
at
(
0
).
get_shape
().
lens
()}}};
argument
arg_1
{{
shape
::
int32_
type
,
{
args
.
at
(
1
).
get_shape
().
lens
()}}};
arg_0
.
visit
([
&
](
auto
output
)
{
args
.
at
(
0
).
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
argument
arg_0
{{
output_shape
.
type
()
,
{
args
.
at
(
0
).
get_shape
().
lens
()}}};
argument
arg_1
{{
output_shape
.
type
()
,
{
args
.
at
(
1
).
get_shape
().
lens
()}}};
if
(
output_shape
.
type
()
==
migraphx
::
shape
::
float_type
)
{
arg_0
.
visit
([
&
](
auto
output
)
{
args
.
at
(
0
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
float
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
int32_t
{
1
},
int32_t
{
0
});
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
float
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
1.0
f
,
0.0
f
);
}
else
if
(
output_shape
.
type
()
==
migraphx
::
shape
::
int32_type
)
{
arg_0
.
visit
([
&
](
auto
output
)
{
args
.
at
(
0
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
int32_t
>
(
x
);
});
});
});
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[
&
](
const
auto
x
)
{
return
static_cast
<
int32_t
>
(
x
);
});
});
});
migemm
(
result
,
arg_0
,
arg_1
,
int32_t
{
1
},
int32_t
{
0
});
}
return
result
;
}
};
...
...
src/targets/ref/target.cpp
View file @
c52d1f69
...
...
@@ -24,7 +24,6 @@
#include <migraphx/ref/target.hpp>
#include <migraphx/ref/lowering.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/auto_contiguous.hpp>
...
...
@@ -43,20 +42,17 @@ std::string target::name() const { return "ref"; }
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
,
const
compile_options
&
)
const
{
return
{
normalize_ops
{},
eliminate_pad
{},
dead_code_elimination
{},
eliminate_data_type
{{
migraphx
::
shape
::
fp8e4m3fnuz_type
},
shape
::
float_type
,
{
"quant_dot"
}},
dead_code_elimination
{},
insert_pad
{},
dead_code_elimination
{},
rewrite_rnn
{},
dead_code_elimination
{},
auto_contiguous
{},
dead_code_elimination
{},
lowering
{},
dead_code_elimination
{}};
return
{
normalize_ops
{},
eliminate_pad
{},
dead_code_elimination
{},
insert_pad
{},
dead_code_elimination
{},
rewrite_rnn
{},
dead_code_elimination
{},
auto_contiguous
{},
dead_code_elimination
{},
lowering
{},
dead_code_elimination
{}};
}
argument
target
::
allocate
(
const
shape
&
s
)
const
{
return
fill_argument
(
s
,
0
);
}
...
...
test/verify/main.cpp
View file @
c52d1f69
...
...
@@ -76,7 +76,11 @@ int main(int argc, const char* argv[])
"test_select_module_conv"
,
"test_split_single_dyn_dim"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
});
"test_instancenorm_large_3d<migraphx::shape::half_type>"
,
// these tests are disabled due issue of lossy downcast, see issue#2517
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"
});
rv
.
disable_test_for
(
"gpu"
,
{
"test_conv_bn_add"
});
rv
.
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment