Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e7ec015d
Commit
e7ec015d
authored
Jul 21, 2022
by
Ted Themistokleous
Browse files
fixup! Make divzero a builtin instead of op
parent
a159fde1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
29 additions
and
88 deletions
+29
-88
src/include/migraphx/builtin.hpp
src/include/migraphx/builtin.hpp
+1
-22
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+0
-2
src/module.cpp
src/module.cpp
+0
-10
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+1
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+2
-51
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+25
-2
No files found.
src/include/migraphx/builtin.hpp
View file @
e7ec015d
...
@@ -100,28 +100,7 @@ struct returns
...
@@ -100,28 +100,7 @@ struct returns
struct
divzero
struct
divzero
{
{
std
::
string
name
()
const
{
return
"@divzero"
;
}
std
::
string
name
()
const
{
return
"@divzero"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
{};
}
{
// taken from the binary.hpp. We're replacing op so don't need the check
// check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
==
s1
and
s0
.
packed
())
{
return
s0
;
}
else
if
(
s0
.
packed
()
!=
s1
.
packed
())
{
return
s0
.
packed
()
?
s0
:
s1
;
}
else
if
(
s0
.
broadcasted
()
!=
s1
.
broadcasted
())
{
return
s0
.
broadcasted
()
?
s1
.
with_lens
(
s0
.
lens
())
:
s0
.
with_lens
(
s0
.
lens
());
}
else
{
return
{
s0
.
type
(),
s0
.
lens
()};
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
{
MIGRAPHX_THROW
(
"builtin"
);
MIGRAPHX_THROW
(
"builtin"
);
...
...
src/include/migraphx/module.hpp
View file @
e7ec015d
...
@@ -166,8 +166,6 @@ struct module
...
@@ -166,8 +166,6 @@ struct module
instruction_ref
add_divzero
(
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
add_divzero
(
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
replace_divzero
(
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
args
);
std
::
vector
<
std
::
string
>
get_parameter_names
()
const
;
std
::
vector
<
std
::
string
>
get_parameter_names
()
const
;
shape
get_parameter_shape
(
std
::
string
name
)
const
;
shape
get_parameter_shape
(
std
::
string
name
)
const
;
...
...
src/module.cpp
View file @
e7ec015d
...
@@ -486,16 +486,6 @@ instruction_ref module::add_divzero(std::vector<instruction_ref> args)
...
@@ -486,16 +486,6 @@ instruction_ref module::add_divzero(std::vector<instruction_ref> args)
auto
result
=
std
::
prev
(
impl
->
instructions
.
end
());
auto
result
=
std
::
prev
(
impl
->
instructions
.
end
());
instruction
::
backreference
(
result
);
instruction
::
backreference
(
result
);
assert
(
result
->
valid
(
begin
()));
assert
(
result
->
valid
(
begin
()));
return
result
;
}
instruction_ref
module
::
replace_divzero
(
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
args
)
MIGRAPHX_TIDY_CONST
{
auto
prev
=
std
::
prev
(
ins
);
shape
r
=
compute_shape
(
prev
->
get_operator
(),
args
);
auto
result
=
instruction
::
replace
(
builtin
::
divzero
{},
ins
->
get_operator
(),
r
,
std
::
move
(
args
));
return
result
;
return
result
;
}
}
...
...
src/simplify_algebra.cpp
View file @
e7ec015d
...
@@ -863,7 +863,7 @@ struct find_zero_div_const
...
@@ -863,7 +863,7 @@ struct find_zero_div_const
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
m
.
replace
_divzero
(
c_
ins
,
ins
->
inputs
()
);
m
.
add
_divzero
(
{
ins
,
c_ins
}
);
}
}
};
};
...
...
test/ref_ops_test.cpp
View file @
e7ec015d
...
@@ -1339,60 +1339,11 @@ TEST_CASE(div_test)
...
@@ -1339,60 +1339,11 @@ TEST_CASE(div_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
div_zero_compile_trap_after_no_passes
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
zero
=
mm
->
add_literal
(
0
);
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
mm
->
add_divzero
({
x
,
zero
});
bool
result
=
false
;
try
{
p
.
compile
(
migraphx
::
ref
::
target
{});
}
catch
(
const
std
::
runtime_error
&
e
)
{
(
void
)
e
;
result
=
true
;
}
EXPECT
(
result
);
}
TEST_CASE
(
div_zero_compile_trap_long_program_no_passes
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
zero
=
mm
->
add_literal
(
0.0
f
);
auto
one
=
mm
->
add_literal
(
1.0
f
);
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
y
=
mm
->
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
div0
=
mm
->
add_divzero
({
x
,
zero
});
std
::
cout
<<
*
mm
<<
std
::
endl
;
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
one
,
div0
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
y
,
mul
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
y
,
add
);
bool
result
=
false
;
try
{
p
.
compile
(
migraphx
::
ref
::
target
{});
}
catch
(
const
std
::
runtime_error
&
e
)
{
(
void
)
e
;
result
=
true
;
}
EXPECT
(
result
);
}
TEST_CASE
(
div_zero_compile_trap_after_passes
)
TEST_CASE
(
div_zero_compile_trap_after_passes
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
zero
=
mm
->
add_literal
(
0
);
auto
zero
=
mm
->
add_literal
(
0
.0
f
);
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
zero
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
zero
);
run_pass
(
*
mm
);
run_pass
(
*
mm
);
...
@@ -1414,7 +1365,7 @@ TEST_CASE(div_zero_compile_trap_long_program_after_passes)
...
@@ -1414,7 +1365,7 @@ TEST_CASE(div_zero_compile_trap_long_program_after_passes)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
zero
=
mm
->
add_literal
(
0.0
);
auto
zero
=
mm
->
add_literal
(
0.0
f
);
auto
two
=
mm
->
add_literal
(
2.0
f
);
auto
two
=
mm
->
add_literal
(
2.0
f
);
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
y
=
mm
->
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
y
=
mm
->
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
...
...
test/simplify_algebra_test.cpp
View file @
e7ec015d
...
@@ -1092,7 +1092,6 @@ TEST_CASE(simplify_sub_neg_zero_const_vec)
...
@@ -1092,7 +1092,6 @@ TEST_CASE(simplify_sub_neg_zero_const_vec)
auto
x
=
m2
.
add_parameter
(
"x"
,
outer
);
auto
x
=
m2
.
add_parameter
(
"x"
,
outer
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"neg"
),
x
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"neg"
),
x
);
}
}
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
...
@@ -1110,9 +1109,33 @@ TEST_CASE(simplify_div_zero_const)
...
@@ -1110,9 +1109,33 @@ TEST_CASE(simplify_div_zero_const)
{
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
zero
=
m2
.
add_literal
(
0
);
auto
zero
=
m2
.
add_literal
(
0
);
m2
.
add_divzero
({
x
,
zero
});
auto
div0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
zero
);
m2
.
add_divzero
({
div0
,
zero
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_div_zero_const_middle
)
{
// May looks strange but intent here is to generate a zero via
// simplify algebra passes that causes division by zero
migraphx
::
module
m1
;
{
auto
zero
=
m1
.
add_literal
(
0
);
auto
two
=
m1
.
add_literal
(
2
);
auto
mul
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
zero
,
two
);
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
div0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
mul
);
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
div0
,
two
);
}
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
zero
=
m2
.
add_literal
(
0
);
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
div0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
zero
);
m2
.
add_divzero
({
div0
,
zero
});
}
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment