Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
84ecee26
Commit
84ecee26
authored
Jun 14, 2019
by
Shucai Xiao
Browse files
clang format
parent
249f5024
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
16 deletions
+20
-16
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+2
-2
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+18
-14
No files found.
src/py/migraphx_py.cpp
View file @
84ecee26
...
...
@@ -104,7 +104,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
}
});
if
(
n
==
0
)
if
(
n
==
0
)
{
MIGRAPHX_THROW
(
"MIGRAPHX PYTHON: Unsupported data type"
+
info
.
format
);
}
...
...
@@ -140,7 +140,7 @@ PYBIND11_MODULE(migraphx, m)
.
def
(
"__init__"
,
[](
migraphx
::
argument
&
x
,
py
::
buffer
b
)
{
py
::
buffer_info
info
=
b
.
request
();
auto
s
=
to_shape
(
info
);
auto
s
=
to_shape
(
info
);
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
})
.
def
(
"get_shape"
,
&
migraphx
::
argument
::
get_shape
)
...
...
src/rewrite_rnn.cpp
View file @
84ecee26
...
...
@@ -903,17 +903,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
migraphx
::
shape
r_shape
=
r
->
get_shape
();
long
seq_len
=
static_cast
<
long
>
(
seq_shape
.
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r_shape
.
lens
()[
2
]);
auto
bs
=
ih
->
get_shape
().
lens
()[
1
];
auto
bs
=
ih
->
get_shape
().
lens
()[
1
];
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
// w matrix, squeeze and transpose
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
tsw
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sw
);
// r matrix, squeeze and transpose
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
tsr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sr
);
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
...
...
@@ -931,8 +931,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
ub_wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
4
*
hs
}},
sbias
);
auto
ub_rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
8
*
hs
}},
sbias
);
wb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}},
ub_wb
);
rb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}},
ub_rb
);
wb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}},
ub_wb
);
rb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}},
ub_rb
);
}
// peep hole
...
...
@@ -959,23 +961,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
xt_sih
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tsw
,
wb
);
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tsw
,
wb
);
auto
sih_tsr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tsr
,
rb
);
xt_sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_tsw
,
sih_tsr
);
xt_sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_tsw
,
sih_tsr
);
}
else
{
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tsw
);
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tsw
);
auto
sih_tsr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tsr
);
xt_sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_tsw
,
sih_tsr
);
xt_sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_tsw
,
sih_tsr
);
}
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_sih
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
hs
},
{
2
*
hs
}},
xt_sih
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
2
*
hs
},
{
3
*
hs
}},
xt_sih
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
3
*
hs
},
{
4
*
hs
}},
xt_sih
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
2
*
hs
},
{
3
*
hs
}},
xt_sih
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
3
*
hs
},
{
4
*
hs
}},
xt_sih
);
if
(
pph
!=
prog
.
end
())
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment