Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9b32ae0c
Commit
9b32ae0c
authored
Aug 12, 2019
by
Paul
Browse files
Reduce dims in pooling
parent
d1e0225d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
79 additions
and
18 deletions
+79
-18
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+1
-1
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+6
-1
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
+6
-0
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
...targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
+66
-16
No files found.
src/include/migraphx/op/reshape.hpp
View file @
9b32ae0c
...
@@ -59,7 +59,7 @@ struct reshape
...
@@ -59,7 +59,7 @@ struct reshape
shape
s
{
inputs
.
front
().
type
(),
rdims
};
shape
s
{
inputs
.
front
().
type
(),
rdims
};
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPHX_THROW
(
"Wrong number of elements for reshape
"
);
MIGRAPHX_THROW
(
"Wrong number of elements for reshape
: reshape has "
+
std
::
to_string
(
s
.
elements
())
+
" elements whereas the input has "
+
std
::
to_string
(
inputs
.
front
().
elements
())
);
return
s
;
return
s
;
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/rewrite_pooling.cpp
View file @
9b32ae0c
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
...
@@ -28,7 +29,11 @@ void rewrite_pooling::apply(program& prog) const
...
@@ -28,7 +29,11 @@ void rewrite_pooling::apply(program& prog) const
continue
;
continue
;
if
(
s
.
lens
()[
2
]
!=
op
.
lengths
[
0
]
and
s
.
lens
()[
3
]
!=
op
.
lengths
[
1
])
if
(
s
.
lens
()[
2
]
!=
op
.
lengths
[
0
]
and
s
.
lens
()[
3
]
!=
op
.
lengths
[
1
])
continue
;
continue
;
prog
.
replace_instruction
(
ins
,
op
::
reduce_mean
{{
2
,
3
}},
ins
->
inputs
().
front
());
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
auto
reshape
=
prog
.
insert_instruction
(
ins
,
op
::
reshape
{{
n
*
c
,
-
1
}},
ins
->
inputs
().
front
());
auto
pooling
=
prog
.
insert_instruction
(
ins
,
op
::
reduce_mean
{{
1
}},
reshape
);
prog
.
replace_instruction
(
ins
,
op
::
reshape
{{
n
,
c
,
1
,
1
}},
pooling
);
}
}
}
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
View file @
9b32ae0c
...
@@ -16,6 +16,12 @@ struct hip_array
...
@@ -16,6 +16,12 @@ struct hip_array
MIGRAPHX_DEVICE_CONSTEXPR
T
&
operator
[](
std
::
size_t
i
)
{
return
d
[
i
];
}
MIGRAPHX_DEVICE_CONSTEXPR
T
&
operator
[](
std
::
size_t
i
)
{
return
d
[
i
];
}
MIGRAPHX_DEVICE_CONSTEXPR
const
T
&
operator
[](
std
::
size_t
i
)
const
{
return
d
[
i
];
}
MIGRAPHX_DEVICE_CONSTEXPR
const
T
&
operator
[](
std
::
size_t
i
)
const
{
return
d
[
i
];
}
MIGRAPHX_DEVICE_CONSTEXPR
T
&
front
()
{
return
d
[
0
];
}
MIGRAPHX_DEVICE_CONSTEXPR
const
T
&
front
()
const
{
return
d
[
0
];
}
MIGRAPHX_DEVICE_CONSTEXPR
T
&
back
()
{
return
d
[
N
-
1
];
}
MIGRAPHX_DEVICE_CONSTEXPR
const
T
&
back
()
const
{
return
d
[
N
-
1
];
}
MIGRAPHX_DEVICE_CONSTEXPR
T
*
data
()
{
return
d
;
}
MIGRAPHX_DEVICE_CONSTEXPR
T
*
data
()
{
return
d
;
}
MIGRAPHX_DEVICE_CONSTEXPR
const
T
*
data
()
const
{
return
d
;
}
MIGRAPHX_DEVICE_CONSTEXPR
const
T
*
data
()
const
{
return
d
;
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
View file @
9b32ae0c
...
@@ -209,28 +209,15 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si
...
@@ -209,28 +209,15 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si
}
}
template
<
class
Op
,
class
T
,
class
Input
,
class
Output
>
template
<
class
Op
,
class
T
,
class
Input
,
class
Output
>
void
reduce
(
hipStream_t
stream
,
void
reduce
_multi_impl
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
result
,
const
argument
&
arg
,
const
argument
&
arg
,
Op
op
,
Op
op
,
T
init
,
T
init
,
Input
read_input
,
Input
read_input
,
Output
read_output
)
Output
read_output
,
const
shape
&
reduce_slice
)
{
{
auto
&&
output_shape
=
result
.
get_shape
();
auto
&&
input_shape
=
arg
.
get_shape
();
std
::
vector
<
std
::
size_t
>
reduce_lens
;
std
::
transform
(
output_shape
.
lens
().
begin
(),
output_shape
.
lens
().
end
(),
input_shape
.
lens
().
begin
(),
std
::
back_inserter
(
reduce_lens
),
[](
auto
x
,
auto
y
)
->
std
::
size_t
{
if
(
x
==
y
)
return
1
;
else
return
y
;
});
shape
reduce_slice
{
output_shape
.
type
(),
reduce_lens
};
hip_visit_all
(
result
,
arg
,
reduce_slice
)([
&
](
auto
output
,
auto
input
,
auto
reduce_shape
)
{
hip_visit_all
(
result
,
arg
,
reduce_slice
)([
&
](
auto
output
,
auto
input
,
auto
reduce_shape
)
{
auto
nelements
=
result
.
get_shape
().
elements
();
auto
nelements
=
result
.
get_shape
().
elements
();
auto
relements
=
reduce_slice
.
elements
();
auto
relements
=
reduce_slice
.
elements
();
...
@@ -250,6 +237,69 @@ void reduce(hipStream_t stream,
...
@@ -250,6 +237,69 @@ void reduce(hipStream_t stream,
});
});
}
}
template
<
class
Op
,
class
T
,
class
Input
,
class
Output
>
void
reduce_standard_impl
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
Op
op
,
T
init
,
Input
read_input
,
Output
read_output
,
std
::
size_t
relements
,
std
::
size_t
stride
)
{
hip_visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
auto
nelements
=
result
.
get_shape
().
elements
();
const
std
::
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
relements
,
max_block_size
);
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
const
auto
out_idx
=
i
/
block_size
;
const
auto
base_idx
=
out_idx
*
stride
;
auto
r
=
block_reduce
<
max_block_size
>
(
idx
,
op
,
init
,
relements
,
[
&
](
auto
j
)
__device__
{
return
read_input
(
input
.
data
()[
base_idx
+
j
]);
});
if
(
idx
.
local
==
0
)
output
.
data
()[
out_idx
]
=
read_output
(
r
);
});
});
}
template
<
class
Op
,
class
T
,
class
Input
,
class
Output
>
void
reduce
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
Op
op
,
T
init
,
Input
read_input
,
Output
read_output
)
{
auto
&&
output_shape
=
result
.
get_shape
();
auto
&&
input_shape
=
arg
.
get_shape
();
if
(
input_shape
.
standard
()
and
output_shape
.
standard
()
and
output_shape
.
lens
().
back
()
!=
input_shape
.
lens
().
back
()
and
std
::
equal
(
output_shape
.
lens
().
begin
(),
std
::
prev
(
output_shape
.
lens
().
end
()),
input_shape
.
lens
().
begin
()))
{
std
::
size_t
stride
=
std
::
accumulate
(
input_shape
.
strides
().
begin
(),
input_shape
.
strides
().
end
(),
1
,
std
::
multiplies
<
size_t
>
());
reduce_standard_impl
(
stream
,
result
,
arg
,
op
,
init
,
read_input
,
read_output
,
input_shape
.
lens
().
back
(),
stride
);
}
else
{
std
::
vector
<
std
::
size_t
>
reduce_lens
;
std
::
transform
(
output_shape
.
lens
().
begin
(),
output_shape
.
lens
().
end
(),
input_shape
.
lens
().
begin
(),
std
::
back_inserter
(
reduce_lens
),
[](
auto
x
,
auto
y
)
->
std
::
size_t
{
if
(
x
==
y
)
return
1
;
else
return
y
;
});
shape
reduce_slice
{
output_shape
.
type
(),
reduce_lens
};
reduce_multi_impl
(
stream
,
result
,
arg
,
op
,
init
,
read_input
,
read_output
,
reduce_slice
);
}
}
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment