Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f919cb7e
Commit
f919cb7e
authored
Jul 18, 2022
by
Ted Themistokleous
Browse files
Work in progress. Adding in divzero instruction related things
Conversion works, just issues with predicate right now.
parent
558ca0fe
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
72 additions
and
8 deletions
+72
-8
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/module.cpp
src/module.cpp
+11
-3
src/program.cpp
src/program.cpp
+2
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+1
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+52
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+5
-4
No files found.
src/CMakeLists.txt
View file @
f919cb7e
...
@@ -126,6 +126,7 @@ register_migraphx_ops(
...
@@ -126,6 +126,7 @@ register_migraphx_ops(
deconvolution
deconvolution
dequantizelinear
dequantizelinear
div
div
divzero
dot
dot
elu
elu
equal
equal
...
...
src/module.cpp
View file @
f919cb7e
...
@@ -630,24 +630,32 @@ instruction_ref module::find_dangling_reference() const
...
@@ -630,24 +630,32 @@ instruction_ref module::find_dangling_reference() const
bool
is_div_zero
(
instruction_ref
ins
)
bool
is_div_zero
(
instruction_ref
ins
)
{
{
const
auto
&
op
=
instruction
::
get_output_alias
(
ins
)
->
name
();
const
auto
&
op
=
instruction
::
get_output_alias
(
ins
)
->
get_operator
();
return
op
==
"@divzero"
;
std
::
cout
<<
op
.
name
()
<<
std
::
endl
;
return
op
.
name
().
find
(
"divzero"
)
!=
std
::
string
::
npos
;
}
}
instruction_ref
module
::
find_division_by_zero
()
const
instruction_ref
module
::
find_division_by_zero
()
const
{
{
std
::
cout
<<
"start search"
<<
std
::
endl
;
auto
last
=
std
::
prev
(
end
());
auto
last
=
std
::
prev
(
end
());
if
(
last
->
name
()
==
"
@
divzero"
)
if
(
last
->
name
()
==
"divzero"
)
{
{
std
::
cout
<<
"search"
<<
std
::
endl
;
auto
div_zero
=
std
::
find_if
(
auto
div_zero
=
std
::
find_if
(
last
->
inputs
().
begin
(),
last
->
inputs
().
end
(),
[](
auto
x
)
{
return
is_div_zero
(
x
);
});
last
->
inputs
().
begin
(),
last
->
inputs
().
end
(),
[](
auto
x
)
{
return
is_div_zero
(
x
);
});
if
(
div_zero
!=
last
->
inputs
().
end
())
if
(
div_zero
!=
last
->
inputs
().
end
())
{
std
::
cout
<<
"found divzero"
<<
std
::
endl
;
return
*
div_zero
;
return
*
div_zero
;
}
}
}
else
if
(
is_div_zero
(
last
))
else
if
(
is_div_zero
(
last
))
{
{
std
::
cout
<<
"check last ref"
<<
std
::
endl
;
return
last
;
return
last
;
}
}
std
::
cout
<<
"End ref"
<<
std
::
endl
;
return
end
();
return
end
();
}
}
...
...
src/program.cpp
View file @
f919cb7e
...
@@ -195,6 +195,8 @@ void program::compile(const target& t, compile_options options)
...
@@ -195,6 +195,8 @@ void program::compile(const target& t, compile_options options)
std
::
to_string
(
index
));
std
::
to_string
(
index
));
}
}
std
::
cout
<<
"find div by zero"
<<
std
::
endl
;
std
::
cout
<<
*
mod
<<
std
::
endl
;
auto
divide_by_zero
=
mod
->
find_division_by_zero
();
auto
divide_by_zero
=
mod
->
find_division_by_zero
();
if
(
divide_by_zero
!=
mod
->
end
())
if
(
divide_by_zero
!=
mod
->
end
())
{
{
...
...
src/simplify_algebra.cpp
View file @
f919cb7e
...
@@ -862,7 +862,7 @@ struct find_zero_div_const
...
@@ -862,7 +862,7 @@ struct find_zero_div_const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
m
.
replace_instruction
(
ins
,
make_op
(
"divzero"
));
m
.
replace_instruction
(
ins
,
make_op
(
"divzero"
)
,
ins
->
inputs
()
);
}
}
};
};
...
...
test/ref_ops_test.cpp
View file @
f919cb7e
...
@@ -37,6 +37,10 @@
...
@@ -37,6 +37,10 @@
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -47,6 +51,11 @@ float sigmoid(float x) { return 1 / (1 + expf(-x)); }
...
@@ -47,6 +51,11 @@ float sigmoid(float x) { return 1 / (1 + expf(-x)); }
float
elu
(
float
a
,
float
x
)
{
return
x
>
0
?
x
:
a
*
std
::
expm1
(
x
);
}
float
elu
(
float
a
,
float
x
)
{
return
x
>
0
?
x
:
a
*
std
::
expm1
(
x
);
}
void
run_pass
(
migraphx
::
module
&
m
)
{
migraphx
::
run_passes
(
m
,
{
migraphx
::
simplify_algebra
{},
migraphx
::
dead_code_elimination
{}});
}
TEST_CASE
(
abs_test
)
TEST_CASE
(
abs_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -1330,6 +1339,49 @@ TEST_CASE(div_test)
...
@@ -1330,6 +1339,49 @@ TEST_CASE(div_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
div_zero_compile_trap_after_no_passes
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
zero
=
mm
->
add_literal
(
0
);
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
mm
->
add_instruction
(
migraphx
::
make_op
(
"divzero"
),
x
,
zero
);
bool
result
=
false
;
try
{
p
.
compile
(
migraphx
::
ref
::
target
{});
}
catch
(
const
std
::
runtime_error
&
e
)
{
(
void
)
e
;
result
=
true
;
}
EXPECT
(
result
);
}
TEST_CASE
(
div_zero_compile_trap_after_passes
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
zero
=
mm
->
add_literal
(
0
);
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
zero
);
run_pass
(
*
mm
);
bool
result
=
false
;
try
{
p
.
compile
(
migraphx
::
ref
::
target
{});
}
catch
(
const
std
::
runtime_error
&
e
)
{
(
void
)
e
;
result
=
true
;
}
EXPECT
(
result
);
}
TEST_CASE
(
elu_test
)
TEST_CASE
(
elu_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/simplify_algebra_test.cpp
View file @
f919cb7e
...
@@ -1101,15 +1101,16 @@ TEST_CASE(simplify_div_zero_const)
...
@@ -1101,15 +1101,16 @@ TEST_CASE(simplify_div_zero_const)
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
unit
=
m1
.
add_literal
(
0
);
auto
zero
=
m1
.
add_literal
(
0
);
m1
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
unit
);
m1
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
zero
);
}
}
run_pass
(
m1
);
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
unit
=
m
1
.
add_literal
(
0
);
auto
zero
=
m
2
.
add_literal
(
0
);
m
1
.
add_instruction
(
migraphx
::
make_op
(
"divzero"
),
x
,
unit
);
m
2
.
add_instruction
(
migraphx
::
make_op
(
"divzero"
),
x
,
zero
);
}
}
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment