Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3809fcb4
Commit
3809fcb4
authored
Jun 25, 2019
by
Paul
Browse files
Reduce block size for reductions
parent
1e2ef8fa
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
56 additions
and
11 deletions
+56
-11
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
+8
-0
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+32
-3
src/targets/gpu/device/reduce_sum.cpp
src/targets/gpu/device/reduce_sum.cpp
+16
-8
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
View file @
3809fcb4
...
@@ -50,6 +50,14 @@ struct hip_array
...
@@ -50,6 +50,14 @@ struct hip_array
result
[
i
]
=
x
[
i
]
*
y
[
i
];
result
[
i
]
=
x
[
i
]
*
y
[
i
];
return
result
;
return
result
;
}
}
friend
MIGRAPHX_DEVICE_CONSTEXPR
hip_array
operator
+
(
const
hip_array
&
x
,
const
hip_array
&
y
)
{
hip_array
result
;
for
(
std
::
size_t
i
=
0
;
i
<
N
;
i
++
)
result
[
i
]
=
x
[
i
]
+
y
[
i
];
return
result
;
}
};
};
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
View file @
3809fcb4
...
@@ -14,6 +14,36 @@ struct index
...
@@ -14,6 +14,36 @@ struct index
std
::
size_t
global
;
std
::
size_t
global
;
std
::
size_t
local
;
std
::
size_t
local
;
std
::
size_t
group
;
std
::
size_t
group
;
__device__
std
::
size_t
nglobal
()
const
{
return
blockDim
.
x
*
gridDim
.
x
;
}
__device__
std
::
size_t
nlocal
()
const
{
return
blockDim
.
x
;
}
template
<
class
F
>
__device__
void
global_stride
(
std
::
size_t
n
,
F
f
)
const
{
const
auto
stride
=
nglobal
();
for
(
std
::
size_t
i
=
global
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
}
template
<
class
F
>
__device__
void
local_stride
(
std
::
size_t
n
,
F
f
)
const
{
const
auto
stride
=
nlocal
();
for
(
std
::
size_t
i
=
local
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
}
};
};
template
<
class
F
>
template
<
class
F
>
...
@@ -54,10 +84,9 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102
...
@@ -54,10 +84,9 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
{
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
idx
.
global_stride
(
n
,
[
&
](
auto
i
)
{
{
gs_invoke
(
f
,
i
,
idx
);
gs_invoke
(
f
,
i
,
idx
);
}
}
);
});
});
};
};
}
}
...
...
src/targets/gpu/device/reduce_sum.cpp
View file @
3809fcb4
...
@@ -22,17 +22,16 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
...
@@ -22,17 +22,16 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
using
type
=
decltype
(
f
(
idx
.
local
));
using
type
=
decltype
(
f
(
idx
.
local
));
MIGRAPHX_DEVICE_SHARED
type
buffer
[
N
];
MIGRAPHX_DEVICE_SHARED
type
buffer
[
N
];
type
x
=
init
;
type
x
=
init
;
for
(
size_t
i
=
idx
.
local
;
i
<
n
;
i
+=
N
)
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
{
x
=
op
(
x
,
f
(
i
));
x
=
op
(
x
,
f
(
i
));
}
}
);
buffer
[
idx
.
local
]
=
x
;
buffer
[
idx
.
local
]
=
x
;
__syncthreads
();
__syncthreads
();
for
(
std
::
size_t
s
=
1
;
s
<
N
;
s
*=
2
)
for
(
std
::
size_t
s
=
1
;
s
<
idx
.
nlocal
()
;
s
*=
2
)
{
{
const
std
::
size_t
index
=
2
*
s
*
idx
.
local
;
const
std
::
size_t
index
=
2
*
s
*
idx
.
local
;
if
(
index
<
N
)
if
(
index
<
idx
.
nlocal
()
)
{
{
buffer
[
index
]
=
op
(
buffer
[
index
],
buffer
[
index
+
s
]);
buffer
[
index
]
=
op
(
buffer
[
index
],
buffer
[
index
+
s
]);
}
}
...
@@ -41,6 +40,14 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
...
@@ -41,6 +40,14 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
return
buffer
[
0
];
return
buffer
[
0
];
}
}
constexpr
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
)
{
size_t
block_size
=
1
;
while
(
block_size
<
max_block_size
and
block_size
<
n
)
block_size
*=
2
;
return
block_size
;
}
void
reduce_sum
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
)
void
reduce_sum
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
)
{
{
auto
&&
output_shape
=
result
.
get_shape
();
auto
&&
output_shape
=
result
.
get_shape
();
...
@@ -61,11 +68,12 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
...
@@ -61,11 +68,12 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
auto
nelements
=
result
.
get_shape
().
elements
();
auto
nelements
=
result
.
get_shape
().
elements
();
auto
relements
=
reduce_slice
.
elements
();
auto
relements
=
reduce_slice
.
elements
();
const
std
::
size_t
block_size
=
1024
;
const
std
::
size_t
max_block_size
=
1024
;
const
std
::
size_t
block_size
=
compute_block_size
(
relements
,
max_block_size
);
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
auto
base_idx
=
output
.
get_shape
().
multi
(
i
/
block_size
);
auto
base_idx
=
output
.
get_shape
().
multi
(
i
/
block_size
);
auto
offset
=
input
.
get_shape
().
index
(
base_idx
);
auto
offset
=
input
.
get_shape
().
index
(
base_idx
);
auto
r
=
block_reduce
<
block_size
>
(
idx
,
sum
{},
0
,
relements
,
[
&
](
auto
j
)
__device__
{
auto
r
=
block_reduce
<
max_
block_size
>
(
idx
,
sum
{},
0
,
relements
,
[
&
](
auto
j
)
__device__
{
return
input
.
data
()[
reduce_shape
.
index
(
j
)
+
offset
];
return
input
.
data
()[
reduce_shape
.
index
(
j
)
+
offset
];
});
});
if
(
idx
.
local
==
0
)
if
(
idx
.
local
==
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment