Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
395ec9c8
Commit
395ec9c8
authored
Mar 05, 2019
by
Shucai Xiao
Browse files
merge latest develop to branch
parents
23d0deb0
c1fec2c4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
60 additions
and
5 deletions
+60
-5
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+1
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+0
-3
src/targets/cpu/target.cpp
src/targets/cpu/target.cpp
+4
-2
test/op_shape_test.cpp
test/op_shape_test.cpp
+55
-0
No files found.
src/onnx/onnx.cpp
View file @
395ec9c8
...
@@ -510,6 +510,7 @@ struct onnx_parser
...
@@ -510,6 +510,7 @@ struct onnx_parser
return
add_broadcastable_binary_op
(
l3
,
l4
,
op
::
add
{});
return
add_broadcastable_binary_op
(
l3
,
l4
,
op
::
add
{});
}
}
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
}
}
...
...
src/rewrite_rnn.cpp
View file @
395ec9c8
...
@@ -987,15 +987,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -987,15 +987,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
auto
pphi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
spph
);
auto
pphi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
spph
);
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphi
);
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphi
);
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphi_brcst
);
auto
ppho
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
spph
);
auto
ppho
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
spph
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
ppho
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
ppho
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
ppho_brcst
);
auto
pphf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
spph
);
auto
pphf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
spph
);
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphf
);
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphf
);
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphf_brcst
);
}
}
for
(
long
i
=
0
;
i
<
seq_len
;
++
i
)
for
(
long
i
=
0
;
i
<
seq_len
;
++
i
)
...
...
src/targets/cpu/target.cpp
View file @
395ec9c8
...
@@ -14,8 +14,10 @@ std::string target::name() const { return "cpu"; }
...
@@ -14,8 +14,10 @@ std::string target::name() const { return "cpu"; }
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
)
const
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
)
const
{
{
return
{
auto_contiguous
{},
return
{
rewrite_rnn
{},
rewrite_rnn
{},
dead_code_elimination
{},
auto_contiguous
{},
dead_code_elimination
{},
dead_code_elimination
{},
lowering
{},
lowering
{},
dead_code_elimination
{}};
dead_code_elimination
{}};
...
...
test/op_shape_test.cpp
View file @
395ec9c8
...
@@ -316,6 +316,61 @@ TEST_CASE(gather)
...
@@ -316,6 +316,61 @@ TEST_CASE(gather)
}
}
}
}
TEST_CASE
(
logsoftmax
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
0
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
op
::
logsoftmax
{
axis
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
op
::
logsoftmax
{
axis
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
2
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
op
::
logsoftmax
{
axis
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
3
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
op
::
logsoftmax
{
axis
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
4
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
op
::
logsoftmax
{
axis
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
5
;
throws_shape
(
migraphx
::
op
::
logsoftmax
{
axis
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
-
1
;
throws_shape
(
migraphx
::
op
::
logsoftmax
{
axis
},
input
);
}
}
TEST_CASE
(
dot
)
TEST_CASE
(
dot
)
{
{
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment