gaoqiong / MIGraphX · Commits

Commit 3c7b6d27
authored Feb 07, 2019 by Shucai Xiao

merge rnn operator rewriting into one file, so only one pass is needed

parent 857df64e

Showing 7 changed files with 551 additions and 584 deletions (+551 −584)
    src/CMakeLists.txt                      +0    −1
    src/include/migraphx/rewrite_gru.hpp    +0    −38
    src/include/migraphx/rewrite_rnn.hpp    +14   −1
    src/rewrite_gru.cpp                     +0    −380
    src/rewrite_rnn.cpp                     +537  −158
    src/targets/cpu/target.cpp              +0    −3
    src/targets/gpu/target.cpp              +0    −3
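The practical effect is visible in the target diffs below: the separate rewrite_gru pass and its trailing dead_code_elimination drop out of each target's pass list. As a minimal sketch, the resulting CPU pipeline (taken from the src/targets/cpu/target.cpp hunk further down) looks like this:

    // Sketch of the simplified CPU pass pipeline after this commit,
    // reconstructed from the src/targets/cpu/target.cpp diff below.
    std::vector<pass> target::get_passes(migraphx::context&) const
    {
        return {auto_contiguous{},
                rewrite_rnn{}, // now rewrites both rnn and gru operators in one pass
                dead_code_elimination{},
                lowering{},
                dead_code_elimination{}};
    }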
src/CMakeLists.txt
View file @ 3c7b6d27

@@ -12,7 +12,6 @@ add_library(migraphx
     eliminate_concat.cpp
     fwd_conv_batchnorm_rewrite.cpp
     rewrite_rnn.cpp
-    rewrite_gru.cpp
     env.cpp
     generate.cpp
     instruction.cpp
src/include/migraphx/rewrite_gru.hpp
deleted 100644 → 0
View file @ 857df64e

#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GRU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GRU_HPP

#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Rewrite gru to gemm, mul, and add.
 */
struct rewrite_gru
{
    std::string name() const { return "rewrite_gru"; }
    void apply(program& prog) const;

    private:
    std::vector<instruction_ref> gru_cell(bool is_forward,
                                          program& prog,
                                          instruction_ref ins,
                                          std::vector<instruction_ref> inputs,
                                          int linear_before_reset,
                                          const operation& actv_func1,
                                          const operation& actv_func2) const;

    std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/include/migraphx/rewrite_rnn.hpp
View file @ 3c7b6d27

@@ -21,6 +21,8 @@ struct rewrite_rnn
     void apply(program& prog) const;

     private:
+    // for vanilla rnn operators
+    void apply_vallina_rnn(program& prog, instruction_ref ins) const;
     std::vector<instruction_ref> rnn_cell(bool is_forward,
                                           program& prog,
                                           instruction_ref ins,

@@ -30,8 +32,19 @@ struct rewrite_rnn
                                           instruction_ref bias,
                                           instruction_ref ih,
                                           operation& actv_func) const;
+    std::vector<operation> rnn_actv_funcs(instruction_ref ins) const;
-    std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
+
+    // for gru operators
+    void apply_gru(program& prog, instruction_ref ins) const;
+    std::vector<instruction_ref> gru_cell(bool is_forward,
+                                          program& prog,
+                                          instruction_ref ins,
+                                          std::vector<instruction_ref> inputs,
+                                          int linear_before_reset,
+                                          const operation& actv_func1,
+                                          const operation& actv_func2) const;
+    std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
 };

 } // namespace MIGRAPHX_INLINE_NS
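The body of the merged pass lives in src/rewrite_rnn.cpp, whose diff is collapsed below, so the dispatch can only be sketched from the header above. A plausible shape, assuming apply() walks the program once and routes each operator to the matching private helper:

    // Minimal sketch only — the real implementation is in src/rewrite_rnn.cpp
    // (collapsed in this view); helper names are from rewrite_rnn.hpp above.
    void rewrite_rnn::apply(program& prog) const
    {
        for(auto ins : iterator_for(prog))
        {
            if(ins->name() == "rnn")
                apply_vallina_rnn(prog, ins); // vanilla rnn -> gemm/mul/add
            else if(ins->name() == "gru")
                apply_gru(prog, ins); // gru handling merged in by this commit
        }
    }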
src/rewrite_gru.cpp
deleted 100644 → 0
View file @ 857df64e
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_gru::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() == "gru")
        {
            const auto actv_funcs = compute_actv_funcs(ins);

            // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
            // the 5th one is undefined and ignored by protobuf), so
            // we need to process up to 5 inputs
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[2]->get_shape().lens()[2];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<float> data(ih_shape.elements(), 0.0);

            auto gru_op = any_cast<op::gru>(ins->get_operator());
            op::gru::gru_direction_t dicrt = gru_op.direction;
            instruction_ref last_output{};
            if(dicrt == op::gru::bidirectional)
            {
                // w weight matrix
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

                // r weight matrix
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

                // bias
                instruction_ref bias_forward = prog.end();
                instruction_ref bias_reverse = prog.end();
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // initial hidden state
                instruction_ref ih_forward{};
                instruction_ref ih_reverse{};
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                {
                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret_forward = gru_cell(true,
                                            prog,
                                            ins,
                                            {args[0], w_forward, r_forward, bias_forward, ih_forward},
                                            gru_op.linear_before_reset,
                                            actv_funcs.at(0),
                                            actv_funcs.at(1));
                auto ret_reverse = gru_cell(false,
                                            prog,
                                            ins,
                                            {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
                                            gru_op.linear_before_reset,
                                            actv_funcs.at(2),
                                            actv_funcs.at(3));

                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

                // The following logic is to ensure the last instruction rewritten
                // from the gru operator is a concat
                instruction_ref hidden_state{};
                if(ret_forward[0] == prog.end())
                {
                    hidden_state = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                }
                else
                {
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
                    hidden_state = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
                }
            }
            else
            {
                bool is_forward = (dicrt == op::gru::forward);
                // weight matrices
                auto w = args[1];
                auto r = args[2];

                // bias
                instruction_ref bias = prog.end();
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                {
                    bias = args[3];
                }

                // initial hidden state
                instruction_ref ih{};
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                {
                    ih = args[5];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret = gru_cell(is_forward,
                                    prog,
                                    ins,
                                    {args[0], w, r, bias, ih},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(0),
                                    actv_funcs.at(1));

                last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

                instruction_ref hidden_state{};
                if(ret[0] == prog.end())
                {
                    hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
                }
                else
                {
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
                    hidden_state =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
                }
            }

            // replace the corresponding gru_last_output instruction
            // with last_output, if gru_last_output exists.
            // while loop to handle the case of multiple gru_last_output operators
            auto last_output_it = ins->outputs().begin();
            while(last_output_it != ins->outputs().end())
            {
                last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
                    return i->name() == "gru_last_output";
                });

                if(last_output_it != ins->outputs().end())
                {
                    prog.replace_instruction(*last_output_it, last_output);
                    last_output_it++;
                }
            }
        }
    }
}
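For context, the args indices above follow the ONNX GRU operator's input order: X (args[0]), W (args[1]), R (args[2]), B (args[3]), sequence_lens (args[4], the undefined input mentioned in the comment), and initial_h (args[5]). In ONNX, W has shape [num_directions, 3*hidden_size, input_size], R has [num_directions, 3*hidden_size, hidden_size], and B has [num_directions, 6*hidden_size], which is why gru_cell below slices the weights in hidden_size-sized chunks up to 3*hs and the bias up to 6*hs.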
std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

    instruction_ref hidden_states = prog.end(), last_output;
    long seq_len = static_cast<long>(seq->get_shape().lens()[0]);
    long hs      = static_cast<long>(r->get_shape().lens()[2]);

    migraphx::shape s(seq->get_shape().type(),
                      {seq->get_shape().lens()[1], static_cast<std::size_t>(hs)});
    std::vector<int> data(s.elements(), 1);
    auto l1 = prog.add_literal(migraphx::literal{s, data});

    // weight matrix
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);

    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // initial states
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
    instruction_ref brcst_bz{};
    instruction_ref brcst_br{};
    instruction_ref brcst_wbh{};
    instruction_ref brcst_rbh{};
    instruction_ref brcst_bh{};
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);

        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);

        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
    }

    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

        // equation f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
        if(bias != prog.end())
        {
            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
        }
        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
        if(bias != prog.end())
        {
            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
        }
        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);

        instruction_ref xht_h;
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
            if(bias != prog.end())
            {
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
            }
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
            }
        }
        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states = (seq_index == 0)
                                    ? last_output
                                    : prog.insert_instruction(
                                          ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states = (seq_index == seq_len - 1)
                                    ? last_output
                                    : prog.insert_instruction(
                                          ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}
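Written out, the loop body above computes the ONNX GRU cell equations referenced in its comments (f is actv_func1, g is actv_func2, and \odot denotes the elementwise product):

\begin{aligned}
z_t &= f(X_t W_z^T + H_{t-1} R_z^T + W_{bz} + R_{bz})\\
r_t &= f(X_t W_r^T + H_{t-1} R_r^T + W_{br} + R_{br})\\
h_t &= g(X_t W_h^T + (r_t \odot H_{t-1}) R_h^T + W_{bh} + R_{bh}) &&(\text{linear\_before\_reset} = 0)\\
h_t &= g(X_t W_h^T + r_t \odot (H_{t-1} R_h^T + R_{bh}) + W_{bh}) &&(\text{otherwise})\\
H_t &= (1 - z_t) \odot h_t + z_t \odot H_{t-1}
\end{aligned}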
std::vector<operation> rewrite_gru::compute_actv_funcs(instruction_ref ins) const
{
    auto gru_op = any_cast<op::gru>(ins->get_operator());
    // before rewriting the gru operator, we need to ensure
    // there are 4 actv funcs, even if the user did not
    // specify any. If fewer than 4 are given, use the
    // algorithm in parse_gru to make 4 actv functions
    if(gru_op.direction == op::gru::bidirectional)
    {
        if(gru_op.actv_funcs.empty())
            return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0)};
        else if(gru_op.actv_funcs.size() == 2)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1)};
        else if(gru_op.actv_funcs.size() == 3)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(2),
                    gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
    else
    {
        if(gru_op.actv_funcs.empty())
            return {op::sigmoid{}, op::tanh{}};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
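For example, following the defaulting logic above (which mirrors parse_gru): a bidirectional GRU with no activation functions gets {sigmoid, tanh, sigmoid, tanh}; with only {tanh} it gets {tanh, tanh, tanh, tanh}; and with {sigmoid, tanh} the pair is repeated for the reverse direction as {sigmoid, tanh, sigmoid, tanh}. A single-direction GRU needs only one (f, g) pair.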
src/rewrite_rnn.cpp
View file @ 3c7b6d27
(This diff is collapsed in this view.)
src/targets/cpu/target.cpp
View file @ 3c7b6d27

@@ -3,7 +3,6 @@
 #include <migraphx/cpu/lowering.hpp>
 #include <migraphx/auto_contiguous.hpp>
 #include <migraphx/rewrite_rnn.hpp>
-#include <migraphx/rewrite_gru.hpp>
 #include <migraphx/dead_code_elimination.hpp>

 namespace migraphx {

@@ -17,8 +16,6 @@ std::vector<pass> target::get_passes(migraphx::context&) const
     return {auto_contiguous{},
             rewrite_rnn{},
             dead_code_elimination{},
-            rewrite_gru{},
-            dead_code_elimination{},
             lowering{},
             dead_code_elimination{}};
 }
src/targets/gpu/target.cpp
View file @ 3c7b6d27

@@ -16,7 +16,6 @@
 #include <migraphx/common_subexpression_elimination.hpp>
 #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
 #include <migraphx/rewrite_rnn.hpp>
-#include <migraphx/rewrite_gru.hpp>
 #include <migraphx/eliminate_concat.hpp>
 #include <migraphx/gpu/concat_gpu_opt.hpp>

@@ -35,8 +34,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
     dead_code_elimination{},
     rewrite_rnn{},
     dead_code_elimination{},
-    rewrite_gru{},
-    dead_code_elimination{},
     // common_subexpression_elimination{},
     // dead_code_elimination{},
     simplify_algebra{},