Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ec48e189
Commit
ec48e189
authored
Aug 16, 2018
by
Paul
Browse files
Use floor for pooling for now
parent
4886f3e8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
5 deletions
+26
-5
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+26
-5
No files found.
src/include/migraph/operators.hpp
View file @
ec48e189
...
@@ -145,8 +145,8 @@ struct pooling
...
@@ -145,8 +145,8 @@ struct pooling
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
assert
(
lengths
[
0
]
<
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
0
]
<
=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
1
]
<
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
assert
(
lengths
[
1
]
<
=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
return
{
t
,
return
{
t
,
{
{
...
@@ -154,12 +154,12 @@ struct pooling
...
@@ -154,12 +154,12 @@ struct pooling
input
.
lens
()[
1
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
std
::
ptrdiff_t
(
std
::
ceil
((
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
])))
+
static_cast
<
float
>
(
stride
[
0
])))
+
1
)),
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
std
::
ptrdiff_t
(
std
::
ceil
((
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
])))
+
static_cast
<
float
>
(
stride
[
1
])))
+
1
)),
1
)),
}};
}};
...
@@ -236,6 +236,13 @@ struct transpose
...
@@ -236,6 +236,13 @@ struct transpose
{
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
transpose
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}"
;
os
<<
"]"
;
return
os
;
}
};
};
struct
contiguous
struct
contiguous
...
@@ -305,7 +312,7 @@ struct reshape
...
@@ -305,7 +312,7 @@ struct reshape
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
{
os
<<
op
.
name
()
<<
"["
;
os
<<
op
.
name
()
<<
"["
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}
,
"
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}"
;
os
<<
"]"
;
os
<<
"]"
;
return
os
;
return
os
;
}
}
...
@@ -443,6 +450,13 @@ struct flatten
...
@@ -443,6 +450,13 @@ struct flatten
{
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
flatten
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"axis="
<<
op
.
axis
;
os
<<
"]"
;
return
os
;
}
};
};
struct
broadcast
struct
broadcast
{
{
...
@@ -476,6 +490,13 @@ struct broadcast
...
@@ -476,6 +490,13 @@ struct broadcast
{
{
return
{
output_shape
,
std
::
move
(
args
.
at
(
1
).
data
)};
return
{
output_shape
,
std
::
move
(
args
.
at
(
1
).
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
broadcast
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"axis="
<<
op
.
axis
;
os
<<
"]"
;
return
os
;
}
};
};
struct
binary
struct
binary
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment