Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0fc52912
Commit
0fc52912
authored
Jul 30, 2018
by
Paul
Browse files
Refactor operators in onnx parser
parent
02e0dd2a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
81 deletions
+59
-81
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+59
-81
No files found.
src/onnx/onnx.cpp
View file @
0fc52912
...
@@ -51,7 +51,55 @@ struct onnx_parser
...
@@ -51,7 +51,55 @@ struct onnx_parser
onnx_parser
()
onnx_parser
()
{
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
add_generic_op
(
"Add"
,
add
{});
add_generic_op
(
"Div"
,
div
{});
add_generic_op
(
"MatMul"
,
gemm
{});
add_generic_op
(
"Mul"
,
mul
{});
add_generic_op
(
"Relu"
,
activation
{
"relu"
});
add_generic_op
(
"Sub"
,
sub
{});
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
}
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
==
2
and
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
x
,
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
x
,
args
);
});
}
instruction_ref
parse_conv
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
convolution
op
;
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
if
(
contains
(
attributes
,
"pads"
))
{
{
...
@@ -73,11 +121,9 @@ struct onnx_parser
...
@@ -73,11 +121,9 @@ struct onnx_parser
return
prog
.
add_instruction
(
add
{},
l1
,
l2
);
return
prog
.
add_instruction
(
add
{},
l1
,
l2
);
}
}
return
prog
.
add_instruction
(
op
,
args
);
return
prog
.
add_instruction
(
op
,
args
);
});
}
add_op
(
"MatMul"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
gemm
{},
args
);
instruction_ref
parse_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"max"
};
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
if
(
contains
(
attributes
,
"pads"
))
...
@@ -93,11 +139,9 @@ struct onnx_parser
...
@@ -93,11 +139,9 @@ struct onnx_parser
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
}
return
prog
.
add_instruction
(
op
,
args
);
return
prog
.
add_instruction
(
op
,
args
);
});
}
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
activation
{
"relu"
},
args
);
instruction_ref
parse_reshape
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
reshape
op
;
reshape
op
;
if
(
args
.
size
()
==
1
)
if
(
args
.
size
()
==
1
)
{
{
...
@@ -110,78 +154,12 @@ struct onnx_parser
...
@@ -110,78 +154,12 @@ struct onnx_parser
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
}
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
args
[
0
]);
});
}
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
instruction_ref
parse_constant
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
.
add_literal
(
v
);
return
prog
.
add_literal
(
v
);
});
}
add_op
(
"Add"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
add
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
add
{},
args
);
});
add_op
(
"Sub"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
sub
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
sub
{},
args
);
});
add_op
(
"Mul"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
mul
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
mul
{},
args
);
});
add_op
(
"Div"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
div
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
div
{},
args
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment