Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
33212f8f
Commit
33212f8f
authored
Aug 15, 2018
by
Paul
Browse files
Add onnx updates from resnet branch
parent
4d0fdcd5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
3 deletions
+72
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+72
-3
No files found.
src/onnx/onnx.cpp
View file @
33212f8f
...
...
@@ -29,7 +29,7 @@ struct unknown
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
name
()
+
":
not computable"
);
MIGRAPH_THROW
(
"
not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
...
...
@@ -60,8 +60,11 @@ struct onnx_parser
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_max_pooling
);
add_mem_op
(
"AveragePool"
,
&
onnx_parser
::
parse_average_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
}
...
...
@@ -126,7 +129,7 @@ struct onnx_parser
}
instruction_ref
parse_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_
max_
pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"max"
};
if
(
contains
(
attributes
,
"pads"
))
...
...
@@ -144,6 +147,25 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_average_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"average"
};
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_reshape
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -161,6 +183,17 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
args
[
0
]);
}
instruction_ref
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
uint64_t
axis
=
0
;
// if(contains(attributes, "axis"))
// {
// axis = parse_value(attributes.at("axis")).at<int>();
// }
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
}
instruction_ref
parse_constant
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
...
...
@@ -168,6 +201,42 @@ struct onnx_parser
return
prog
.
add_literal
(
v
);
}
instruction_ref
parse_gemm
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
bool
transa
=
false
;
bool
transb
=
false
;
if
(
contains
(
attributes
,
"alpha"
))
{
alpha
=
parse_value
(
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"beta"
))
{
alpha
=
parse_value
(
attributes
.
at
(
"beta"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"transA"
))
{
transa
=
parse_value
(
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
}
if
(
contains
(
attributes
,
"transB"
))
{
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
if
(
args
.
size
()
==
3
)
{
uint64_t
axis
=
1
;
auto
l3
=
prog
.
add_instruction
(
gemm
{
alpha
,
beta
},
l1
,
l2
);
auto
l4
=
prog
.
add_instruction
(
broadcast
{
axis
},
l3
,
args
[
2
]);
return
prog
.
add_instruction
(
add
{},
l3
,
l4
);
}
return
prog
.
add_instruction
(
gemm
{
alpha
,
beta
},
l1
,
l2
);
}
instruction_ref
parse_batchnorm
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment