Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7f05ac8a
Commit
7f05ac8a
authored
Aug 25, 2018
by
Paul
Browse files
Load memory into lds
parent
20bdf794
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
15 deletions
+20
-15
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+20
-15
No files found.
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
7f05ac8a
...
@@ -39,33 +39,37 @@ auto nary_nonstandard(argument result, Arguments... args)
...
@@ -39,33 +39,37 @@ auto nary_nonstandard(argument result, Arguments... args)
inline
auto
binary_broadcast
(
argument
result
,
argument
arg1
,
argument
arg2
)
inline
auto
binary_broadcast
(
argument
result
,
argument
arg1
,
argument
arg2
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
b_shape
.
lens
()[
bdim
];
auto
bdim_len
=
b_shape
.
lens
()[
bdim
];
auto
outer_size
=
std
::
accumulate
(
output_shape
.
lens
().
begin
(),
output_shape
.
lens
().
begin
()
+
bdim
+
1
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
inner_size
=
std
::
accumulate
(
output_shape
.
lens
().
begin
()
+
bdim
+
1
,
output_shape
.
lens
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
input1
.
data
();
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
outer_size
)([
=
](
auto
i
)
{
auto
*
outp2
=
outp
+
i
;
const
std
::
size_t
nlocal
=
256
;
auto
*
xp2
=
xp
+
i
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
auto
b
=
yp
[
i
%
bdim_len
];
const
std
::
size_t
n
=
output
.
size
();
for
(
std
::
size_t
j
=
0
;
j
<
inner_size
;
j
++
)
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
__shared__
type
buffer
[
2048
];
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
{
outp2
[
j
]
=
f
(
xp2
[
j
],
b
);
auto
b
=
buffer
[
i
];
for
(
size_t
j
=
idx
.
global
;
j
<
n
;
j
+=
nglobal
)
{
outp
[
j
]
=
f
(
xp
[
j
],
b
);
}
}
}
});
});
});
});
...
@@ -114,6 +118,7 @@ inline auto nary(argument result, argument arg1, argument arg2)
...
@@ -114,6 +118,7 @@ inline auto nary(argument result, argument arg1, argument arg2)
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// TODO: Check for one broadcast stride
// TODO: Check for one broadcast stride
// TODO: Check result and arg1 shape is the same
// TODO: Check result and arg1 shape is the same
// TODO: CHeck that broadcast shape doesnt have more than 2048 elements
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
broadcasted
()
and
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
broadcasted
()
and
std
::
count_if
(
arg2
.
get_shape
().
strides
().
begin
(),
std
::
count_if
(
arg2
.
get_shape
().
strides
().
begin
(),
arg2
.
get_shape
().
strides
().
end
(),
arg2
.
get_shape
().
strides
().
end
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment