Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0fc52912
Commit
0fc52912
authored
Jul 30, 2018
by
Paul
Browse files
Refactor operators in onnx parser
parent
02e0dd2a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
81 deletions
+59
-81
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+59
-81
No files found.
src/onnx/onnx.cpp
View file @
0fc52912
...
@@ -51,7 +51,55 @@ struct onnx_parser
...
@@ -51,7 +51,55 @@ struct onnx_parser
onnx_parser
()
onnx_parser
()
{
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
add_generic_op
(
"Add"
,
add
{});
add_generic_op
(
"Div"
,
div
{});
add_generic_op
(
"MatMul"
,
gemm
{});
add_generic_op
(
"Mul"
,
mul
{});
add_generic_op
(
"Relu"
,
activation
{
"relu"
});
add_generic_op
(
"Sub"
,
sub
{});
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
}
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
==
2
and
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
x
,
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
x
,
args
);
});
}
instruction_ref
parse_conv
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
convolution
op
;
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
if
(
contains
(
attributes
,
"pads"
))
{
{
...
@@ -73,11 +121,9 @@ struct onnx_parser
...
@@ -73,11 +121,9 @@ struct onnx_parser
return
prog
.
add_instruction
(
add
{},
l1
,
l2
);
return
prog
.
add_instruction
(
add
{},
l1
,
l2
);
}
}
return
prog
.
add_instruction
(
op
,
args
);
return
prog
.
add_instruction
(
op
,
args
);
});
}
add_op
(
"MatMul"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
gemm
{},
args
);
instruction_ref
parse_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"max"
};
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
if
(
contains
(
attributes
,
"pads"
))
...
@@ -93,11 +139,9 @@ struct onnx_parser
...
@@ -93,11 +139,9 @@ struct onnx_parser
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
}
return
prog
.
add_instruction
(
op
,
args
);
return
prog
.
add_instruction
(
op
,
args
);
});
}
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
activation
{
"relu"
},
args
);
instruction_ref
parse_reshape
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
reshape
op
;
reshape
op
;
if
(
args
.
size
()
==
1
)
if
(
args
.
size
()
==
1
)
{
{
...
@@ -110,78 +154,12 @@ struct onnx_parser
...
@@ -110,78 +154,12 @@ struct onnx_parser
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
}
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
args
[
0
]);
});
}
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
instruction_ref
parse_constant
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
.
add_literal
(
v
);
return
prog
.
add_literal
(
v
);
});
}
add_op
(
"Add"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
add
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
add
{},
args
);
});
add_op
(
"Sub"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
sub
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
sub
{},
args
);
});
add_op
(
"Mul"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
mul
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
mul
{},
args
);
});
add_op
(
"Div"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
div
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
div
{},
args
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment