Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a7a3f867
Commit
a7a3f867
authored
Feb 01, 2019
by
Shucai Xiao
Browse files
fix comments.
parent
343a5774
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
31 deletions
+26
-31
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-5
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+6
-5
src/program.cpp
src/program.cpp
+1
-6
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+18
-15
No files found.
src/include/migraphx/operators.hpp
View file @
a7a3f867
...
@@ -632,11 +632,7 @@ struct reshape
...
@@ -632,11 +632,7 @@ struct reshape
rdims
[
i
]
=
missing_dim
;
rdims
[
i
]
=
missing_dim
;
}
}
}
}
// if(dims.back() == -1)
//{
// rdims.pop_back();
// std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
//}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
shape
s
{
inputs
.
front
().
type
(),
rdims
};
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPHX_THROW
(
"Wrong number of elements for reshape"
);
MIGRAPHX_THROW
(
"Wrong number of elements for reshape"
);
...
...
src/onnx/onnx.cpp
View file @
a7a3f867
...
@@ -739,8 +739,9 @@ struct onnx_parser
...
@@ -739,8 +739,9 @@ struct onnx_parser
}
}
});
});
// bidirectional should have two activation functions
// bidirectional case should have two activation functions.
// if only one actv function is provides, we use it in both
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
// forward and reverse direction
if
(
dirct
==
op
::
rnn
::
bidirectional
)
if
(
dirct
==
op
::
rnn
::
bidirectional
)
{
{
...
@@ -750,9 +751,9 @@ struct onnx_parser
...
@@ -750,9 +751,9 @@ struct onnx_parser
}
}
}
}
std
::
vector
<
operation
>
vec_actv_funcs
;
std
::
vector
<
operation
>
vec_actv_funcs
(
vec_names
.
size
())
;
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
fn
)
{
std
::
transform
(
vec_names
.
begin
(),
vec_names
.
end
(),
vec_actv_funcs
.
begin
(),
[
&
](
auto
&
fn
)
{
vec_actv_funcs
.
push_back
(
map_actv_funcs
[
fn
]
)
;
return
map_actv_funcs
[
fn
];
});
});
// To be added later
// To be added later
...
...
src/program.cpp
View file @
a7a3f867
...
@@ -138,12 +138,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -138,12 +138,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
if
(
ins
==
std
::
prev
(
this
->
end
()))
if
(
ins
==
std
::
prev
(
this
->
end
()))
{
{
// additional check to ensure the ins to be replaced is either
return
replace_instruction
(
ins
,
op
::
identity
{},
rep
);
// the rnn_last_output, gru_last_output, or lstm_last_output
if
(
ins
->
name
()
==
"rnn_last_output"
)
{
return
replace_instruction
(
ins
,
op
::
identity
{},
rep
);
}
}
}
// TODO: Should it be an error if the output is empty?
// TODO: Should it be an error if the output is empty?
...
...
src/rewrite_rnn.cpp
View file @
a7a3f867
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void
rewrite_rnn
::
apply
(
program
&
prog
)
const
void
rewrite_rnn
::
apply
(
program
&
prog
)
const
{
{
instruction_ref
last_output
=
prog
.
end
()
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_last_output
;
for
(
auto
ins
:
iterator_for
(
prog
))
for
(
auto
ins
:
iterator_for
(
prog
))
{
{
// rewrite rnn operator
// rewrite rnn operator
...
@@ -87,14 +87,15 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -87,14 +87,15 @@ void rewrite_rnn::apply(program& prog) const
auto
concat_output
=
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
concat_output
);
auto
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
concat_output
);
// The following logic is to ensure the last instruction rewritten from
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// rnn operator is a concat instruction
// sequence len is 1
// sequence len is 1
instruction_ref
hidden_output
{};
if
(
ret_forward
[
0
]
==
prog
.
end
())
if
(
ret_forward
[
0
]
==
prog
.
end
())
{
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
hidden_output
=
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
}
else
else
{
{
...
@@ -102,8 +103,9 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -102,8 +103,9 @@ void rewrite_rnn::apply(program& prog) const
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
hidden_output
=
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
map_last_output
[
hidden_output
]
=
last_output
;
}
}
else
else
{
{
...
@@ -135,21 +137,23 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -135,21 +137,23 @@ void rewrite_rnn::apply(program& prog) const
auto
ret
=
rnn_cell
(
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
auto
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
// following logic is to ensure the last instruction is a
// following logic is to ensure the last instruction is a
// concat instruction
// concat instruction
// sequence len is 1
// sequence len is 1
instruction_ref
hidden_output
{};
if
(
ret
[
0
]
==
prog
.
end
())
if
(
ret
[
0
]
==
prog
.
end
())
{
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
hidden_output
=
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
}
else
else
{
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
hidden_output
=
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
}
}
map_last_output
[
hidden_output
]
=
last_output
;
}
}
}
}
...
@@ -159,16 +163,15 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -159,16 +163,15 @@ void rewrite_rnn::apply(program& prog) const
// so we can just use it as the output here
// so we can just use it as the output here
if
(
ins
->
name
()
==
"rnn_last_output"
)
if
(
ins
->
name
()
==
"rnn_last_output"
)
{
{
// if rnn operator is executed, the last_output != prog.end()
auto
inputs
=
ins
->
inputs
();
if
(
last_output
!=
prog
.
end
())
assert
(
inputs
.
size
()
==
1
);
auto
arg
=
inputs
[
0
];
if
(
map_last_output
.
count
(
arg
)
==
0
)
{
{
prog
.
replace_instruction
(
ins
,
last_output
);
MIGRAPHX_THROW
(
"RNN_LAST_OUTPUT: no related rnn operator as its input"
);
last_output
=
prog
.
end
();
}
else
{
MIGRAPHX_THROW
(
"RNN_LAST_OUTPUT: must put after rnn operator"
);
}
}
prog
.
replace_instruction
(
ins
,
map_last_output
[
arg
]);
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment