Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a87890be
Commit
a87890be
authored
Feb 02, 2019
by
Shucai Xiao
Browse files
commit gru changes.
parent
6b1e5e63
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
46 deletions
+47
-46
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+14
-14
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+13
-10
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+20
-22
No files found.
src/include/migraphx/operators.hpp
View file @
a87890be
...
...
@@ -1173,6 +1173,20 @@ struct rnn
}
};
struct
rnn_last_output
{
std
::
string
name
()
const
{
return
"rnn_last_output"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
dims
=
inputs
[
0
].
lens
();
// remove the first dimension, remaing are output shape
dims
.
erase
(
dims
.
begin
());
return
{
inputs
[
0
].
type
(),
dims
};
}
};
struct
gru
{
enum
gru_direction_t
...
...
@@ -1217,20 +1231,6 @@ struct gru
}
};
struct
rnn_last_output
{
std
::
string
name
()
const
{
return
"rnn_last_output"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
dims
=
inputs
[
0
].
lens
();
// remove the first dimension, remaing are output shape
dims
.
erase
(
dims
.
begin
());
return
{
inputs
[
0
].
type
(),
dims
};
}
};
struct
gru_last_output
{
std
::
string
name
()
const
{
return
"gru_last_output"
;
}
...
...
src/onnx/onnx.cpp
View file @
a87890be
...
...
@@ -832,8 +832,8 @@ struct onnx_parser
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provide
s
,
// repeat 1 four times. If 2 actv functins are provide
s
,
// use the algorithm that: if 1 actv function is provide
d
,
// repeat 1 four times. If 2 actv functins are provide
d
,
// assume forward and reverse use the same pair of actv
// functions. For the case of 3 actv functions provided,
// assume the 3rd one is repeated once and used by the
...
...
@@ -869,12 +869,11 @@ struct onnx_parser
}
});
std
::
vector
<
operation
>
vec_actv_funcs
;
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
vec_actv_funcs
.
push_back
(
map_actv_funcs
[
name
]
)
;
std
::
vector
<
operation
>
vec_actv_funcs
(
vec_names
.
size
())
;
std
::
transform
(
vec_names
.
begin
(),
vec_names
.
end
(),
vec_actv_funcs
.
begin
(),
[
&
](
auto
&
name
)
{
return
map_actv_funcs
[
name
];
});
// To be added later
float
clip
=
0.0
;
if
(
contains
(
attributes
,
"clip"
))
{
...
...
@@ -887,18 +886,22 @@ struct onnx_parser
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
}
std
::
vector
<
instruction_ref
>
result
;
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
6
)
{
auto
ins
=
prog
.
add_instruction
(
op
::
undefined
{});
args
.
insert
(
args
.
end
(),
6
-
args
.
size
(),
ins
);
}
// first output for concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
gru
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
linear_before_reset
},
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
// second output for last gru output
auto
last_output
=
prog
.
add_instruction
(
op
::
gru_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
return
result
;
return
{
hidden_states
,
last_output
}
;
}
void
parse_from
(
std
::
istream
&
is
)
...
...
src/rewrite_gru.cpp
View file @
a87890be
...
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void
rewrite_gru
::
apply
(
program
&
prog
)
const
{
instruction_ref
last_output
=
prog
.
end
()
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_last_output
;
for
(
auto
ins
:
iterator_for
(
prog
))
{
if
(
ins
->
name
()
==
"gru"
)
...
...
@@ -22,9 +22,9 @@ void rewrite_gru::apply(program& prog) const
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
std
::
size_t
batch
s
=
seq_shape
.
lens
()[
1
];
std
::
size_t
batch
_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch
s
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch
_size
,
hidden_size
}};
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
...
...
@@ -42,7 +42,7 @@ void rewrite_gru::apply(program& prog) const
// bias
instruction_ref
bias_forward
,
bias_reverse
;
bias_forward
=
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
get_operator
().
name
()
!=
"undefined"
)
{
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
...
...
@@ -50,12 +50,10 @@ void rewrite_gru::apply(program& prog) const
// intial hidden state
instruction_ref
ih_forward
,
ih_reverse
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
get_operator
().
name
()
!=
"undefined"
)
{
auto
arg_ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
arg_ih
);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
arg_ih
);
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
5
]);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
5
]);
}
else
{
...
...
@@ -87,7 +85,7 @@ void rewrite_gru::apply(program& prog) const
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
3
));
last_output
=
auto
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
// add the dimension of num_direction
...
...
@@ -95,7 +93,8 @@ void rewrite_gru::apply(program& prog) const
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// concat the forward and reverse output
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
auto
hidden_state
=
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
map_last_output
[
hidden_state
]
=
last_output
;
}
else
{
...
...
@@ -106,17 +105,16 @@ void rewrite_gru::apply(program& prog) const
// bias
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
get_operator
().
name
()
!=
"undefined"
)
{
bias
=
args
[
3
];
}
// intial hidden state
instruction_ref
ih
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
get_operator
().
name
()
!=
"undefined"
)
{
ih
=
args
.
size
()
==
6
?
args
[
5
]
:
args
[
4
];
ih
=
args
[
5
];
}
else
{
...
...
@@ -135,10 +133,11 @@ void rewrite_gru::apply(program& prog) const
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
last_output
=
ret
[
1
];
auto
last_output
=
ret
[
1
];
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
auto
hidden_state
=
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
map_last_output
[
hidden_state
]
=
last_output
;
}
}
...
...
@@ -148,11 +147,10 @@ void rewrite_gru::apply(program& prog) const
// so we can just use it as the output here
if
(
ins
->
name
()
==
"gru_last_output"
)
{
if
(
last_output
!=
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
identity
{},
last_output
);
last_output
=
prog
.
end
();
}
auto
inputs
=
ins
->
inputs
();
assert
(
inputs
.
size
()
==
1
);
assert
(
map_last_output
.
count
(
inputs
[
0
])
>
0
);
prog
.
replace_instruction
(
ins
,
map_last_output
[
inputs
[
0
]]);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment