Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5b37c53c
Unverified
Commit
5b37c53c
authored
Mar 09, 2022
by
Charlie Lin
Committed by
GitHub
Mar 09, 2022
Browse files
Celu ONNX parser and tests (#1114)
Add Celu ONNX operator
parent
4467c158
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
234 additions
and
0 deletions
+234
-0
src/onnx/parse_celu.cpp
src/onnx/parse_celu.cpp
+57
-0
test/onnx/celu_alpha_test.onnx
test/onnx/celu_alpha_test.onnx
+12
-0
test/onnx/celu_default_test.onnx
test/onnx/celu_default_test.onnx
+11
-0
test/onnx/celu_verify_test.onnx
test/onnx/celu_verify_test.onnx
+0
-0
test/onnx/celu_wrong_type_test.onnx
test/onnx/celu_wrong_type_test.onnx
+13
-0
test/onnx/celu_zero_alpha_test.onnx
test/onnx/celu_zero_alpha_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+59
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+59
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+22
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+1
-0
No files found.
src/onnx/parse_celu.cpp
0 → 100644
View file @
5b37c53c
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_celu
:
op_parser
<
parse_celu
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Celu"
}};
}
instruction_ref
parse
(
const
op_desc
&
,
const
onnx_parser
&
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
float
alpha
=
1.0
;
if
(
contains
(
info
.
attributes
,
"alpha"
))
{
alpha
=
info
.
attributes
.
at
(
"alpha"
).
f
();
}
if
(
float_equal
(
alpha
,
0.0
f
))
{
MIGRAPHX_THROW
(
"CELU: alpha is zero (division by zero)"
);
}
auto
input_lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
input_type
=
args
[
0
]
->
get_shape
().
type
();
if
(
input_type
!=
migraphx
::
shape
::
float_type
)
{
MIGRAPHX_THROW
(
"CELU: input tensor not float type"
);
}
auto
zero_lit
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
0.
}}));
auto
one_lit
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
1.
}}));
auto
alpha_lit
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
alpha
}}));
auto
linear_part
=
info
.
add_instruction
(
migraphx
::
make_op
(
"max"
),
zero_lit
,
args
[
0
]);
auto
divi
=
info
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
args
[
0
],
alpha_lit
);
auto
expo
=
info
.
add_instruction
(
migraphx
::
make_op
(
"exp"
),
divi
);
auto
sub
=
info
.
add_instruction
(
migraphx
::
make_op
(
"sub"
),
expo
,
one_lit
);
auto
mul
=
info
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
alpha_lit
,
sub
);
auto
exp_part
=
info
.
add_instruction
(
migraphx
::
make_op
(
"min"
),
zero_lit
,
mul
);
return
info
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
linear_part
,
exp_part
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
test/onnx/celu_alpha_test.onnx
0 → 100644
View file @
5b37c53c
celu_alpha_test:R
xy"Celu*
alphaL?celu_alpha_testZ
x
b
y
B
\ No newline at end of file
test/onnx/celu_default_test.onnx
0 → 100644
View file @
5b37c53c
celu_default_test:K
xy"Celucelu_default_testZ
x
b
y
B
\ No newline at end of file
test/onnx/celu_verify_test.onnx
0 → 100644
View file @
5b37c53c
File added
test/onnx/celu_wrong_type_test.onnx
0 → 100644
View file @
5b37c53c
celu_wrong_type_test:N
xy"Celucelu_wrong_type_testZ
x
b
y
B
\ No newline at end of file
test/onnx/celu_zero_alpha_test.onnx
0 → 100644
View file @
5b37c53c
File added
test/onnx/gen_onnx.py
View file @
5b37c53c
...
...
@@ -351,6 +351,65 @@ def ceil_test():
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
celu_alpha_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
3
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
])
node
=
onnx
.
helper
.
make_node
(
'Celu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
alpha
=
0.8
)
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
celu_default_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
2
,
3
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
2
,
3
])
node
=
onnx
.
helper
.
make_node
(
'Celu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
celu_verify_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
2
,
3
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
2
,
3
])
node
=
onnx
.
helper
.
make_node
(
'Celu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
alpha
=
0.5
)
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
celu_wrong_type_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT16
,
[
2
,
3
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT16
,
[
2
,
3
])
node
=
onnx
.
helper
.
make_node
(
'Celu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
celu_zero_alpha_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
2
,
3
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
2
,
3
])
node
=
onnx
.
helper
.
make_node
(
'Celu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
alpha
=
0.0
)
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
clip_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
3
])
...
...
test/onnx/onnx_test.cpp
View file @
5b37c53c
...
...
@@ -46,6 +46,29 @@ migraphx::program optimize_onnx(const std::string& name, bool run_passes = false
return
prog
;
}
void
add_celu_instruction
(
migraphx
::
module
*
mm
,
const
migraphx
::
shape
&
s
,
float
alpha
)
{
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
const
auto
&
input_lens
=
s
.
lens
();
const
auto
&
input_type
=
s
.
type
();
auto
zero_lit
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
0.
}}));
auto
one_lit
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
1.
}}));
auto
alpha_lit
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
alpha
}}));
auto
linear_part
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"max"
),
zero_lit
,
x
);
auto
divi
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
alpha_lit
);
auto
expo
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"exp"
),
divi
);
auto
sub
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
expo
,
one_lit
);
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
alpha_lit
,
sub
);
auto
exp_part
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"min"
),
zero_lit
,
mul
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
linear_part
,
exp_part
);
}
static
std
::
vector
<
double
>
make_r_eyelike
(
size_t
num_rows
,
size_t
num_cols
,
size_t
k
)
{
std
::
vector
<
double
>
eyelike_mat
(
num_rows
*
num_cols
,
0
);
...
...
@@ -380,6 +403,42 @@ TEST_CASE(ceil_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
celu_alpha_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
std
::
size_t
>
input_lens
=
{
3
};
auto
input_type
=
migraphx
::
shape
::
float_type
;
migraphx
::
shape
s
{
input_type
,
input_lens
};
float
alpha
=
0.8
;
add_celu_instruction
(
mm
,
s
,
alpha
);
auto
prog
=
optimize_onnx
(
"celu_alpha_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
celu_default_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
std
::
size_t
>
input_lens
=
{
2
,
3
};
auto
input_type
=
migraphx
::
shape
::
float_type
;
migraphx
::
shape
s
{
input_type
,
input_lens
};
float
alpha
=
1.0
;
add_celu_instruction
(
mm
,
s
,
alpha
);
auto
prog
=
optimize_onnx
(
"celu_default_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
celu_wrong_type_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"celu_wrong_type_test.onnx"
);
}));
}
TEST_CASE
(
celu_zero_alpha_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"celu_zero_alpha_test.onnx"
);
}));
}
TEST_CASE
(
clip_test
)
{
migraphx
::
program
p
;
...
...
test/onnx/verify_onnx.cpp
View file @
5b37c53c
...
...
@@ -45,6 +45,28 @@ TEST_CASE(averagepool_nt_cip_test)
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
TEST_CASE
(
celu_verify_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"celu_verify_test.onnx"
);
p
.
compile
(
migraphx
::
ref
::
target
{});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
std
::
vector
<
float
>
data
=
{
-
5.5
,
2.0
,
100.
,
7.0
,
0.
,
-
1.
};
migraphx
::
parameter_map
pp
;
pp
[
"x"
]
=
migraphx
::
argument
(
s
,
data
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
float
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
correct
(
6
);
float
alpha
=
0.5
;
std
::
transform
(
data
.
begin
(),
data
.
end
(),
correct
.
begin
(),
[
&
](
auto
x
)
{
return
std
::
max
(
0.0
f
,
x
)
+
std
::
min
(
0.0
f
,
alpha
*
std
::
expm1
(
x
/
alpha
));
});
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
correct
));
}
TEST_CASE
(
clip_args_type_mismatch
)
{
auto
p
=
migraphx
::
parse_onnx
(
"clip_test_args_type_mismatch.onnx"
);
...
...
test/py/onnx_backend_test.py
View file @
5b37c53c
...
...
@@ -96,6 +96,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test
.
include
(
r
'.*test_AvgPool.*'
)
backend_test
.
include
(
r
'.*test_BatchNorm.*eval.*'
)
backend_test
.
include
(
r
'.*test_ceil.*'
)
backend_test
.
include
(
r
'.*test_celu.*'
)
backend_test
.
include
(
r
'.*test_clip.*'
)
backend_test
.
include
(
r
'.*test_concat.*'
)
backend_test
.
include
(
r
'.*test_constant.*'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment