Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
33212f8f
Commit
33212f8f
authored
Aug 15, 2018
by
Paul
Browse files
Add onnx updates from resnet branch
parent
4d0fdcd5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
3 deletions
+72
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+72
-3
No files found.
src/onnx/onnx.cpp
View file @
33212f8f
...
...
@@ -29,7 +29,7 @@ struct unknown
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
name
()
+
":
not computable"
);
MIGRAPH_THROW
(
"
not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
...
...
@@ -60,8 +60,11 @@ struct onnx_parser
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_max_pooling
);
add_mem_op
(
"AveragePool"
,
&
onnx_parser
::
parse_average_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
}
...
...
@@ -126,7 +129,7 @@ struct onnx_parser
}
instruction_ref
parse_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_
max_
pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"max"
};
if
(
contains
(
attributes
,
"pads"
))
...
...
@@ -144,6 +147,25 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_average_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"average"
};
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_reshape
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -161,6 +183,17 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
args
[
0
]);
}
instruction_ref
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
uint64_t
axis
=
0
;
// if(contains(attributes, "axis"))
// {
// axis = parse_value(attributes.at("axis")).at<int>();
// }
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
}
instruction_ref
parse_constant
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
...
...
@@ -168,6 +201,42 @@ struct onnx_parser
return
prog
.
add_literal
(
v
);
}
instruction_ref
parse_gemm
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
bool
transa
=
false
;
bool
transb
=
false
;
if
(
contains
(
attributes
,
"alpha"
))
{
alpha
=
parse_value
(
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"beta"
))
{
alpha
=
parse_value
(
attributes
.
at
(
"beta"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"transA"
))
{
transa
=
parse_value
(
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
}
if
(
contains
(
attributes
,
"transB"
))
{
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
if
(
args
.
size
()
==
3
)
{
uint64_t
axis
=
1
;
auto
l3
=
prog
.
add_instruction
(
gemm
{
alpha
,
beta
},
l1
,
l2
);
auto
l4
=
prog
.
add_instruction
(
broadcast
{
axis
},
l3
,
args
[
2
]);
return
prog
.
add_instruction
(
add
{},
l3
,
l4
);
}
return
prog
.
add_instruction
(
gemm
{
alpha
,
beta
},
l1
,
l2
);
}
instruction_ref
parse_batchnorm
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment