Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
16864eef
Commit
16864eef
authored
Jul 23, 2019
by
Paul
Browse files
Formatting
parent
928cb435
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
39 deletions
+38
-39
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+38
-39
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
16864eef
...
@@ -119,7 +119,8 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
...
@@ -119,7 +119,8 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_double_broadcast_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg1
,
argument
barg2
,
Arguments
...
args
)
void
nary_double_broadcast_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg1
,
argument
barg2
,
Arguments
...
args
)
{
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
barg1
.
get_shape
();
const
auto
&
b_shape
=
barg1
.
get_shape
();
...
@@ -135,7 +136,8 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
...
@@ -135,7 +136,8 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg1
,
barg2
,
args
...)([
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
hip_visit_all
(
result
,
barg1
,
barg2
,
args
...)(
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
using
type
=
typename
decltype
(
output
)
::
value_type
;
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2048
];
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2048
];
...
@@ -146,7 +148,7 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
...
@@ -146,7 +148,7 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
}
}
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
{
buffer
[
i
+
bdim_len
]
=
binput2
.
data
()[
i
+
bdim_len
];
buffer
[
i
+
bdim_len
]
=
binput2
.
data
()[
i
+
bdim_len
];
}
}
__syncthreads
();
__syncthreads
();
// Process the data
// Process the data
...
@@ -154,7 +156,7 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
...
@@ -154,7 +156,7 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
{
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b1
=
buffer
[
bidx
];
auto
b1
=
buffer
[
bidx
];
auto
b2
=
buffer
[
bidx
+
bdim_len
];
auto
b2
=
buffer
[
bidx
+
bdim_len
];
output
.
data
()[
i
]
=
f
(
inputs
.
data
()[
i
]...,
b1
,
b2
);
output
.
data
()[
i
]
=
f
(
inputs
.
data
()[
i
]...,
b1
,
b2
);
}
}
});
});
...
@@ -219,15 +221,15 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
...
@@ -219,15 +221,15 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
};
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
bool
broadcastable
(
bool
&
divisible_by_4
,
argument
result
,
argument
barg
,
Arguments
...
args
)
bool
broadcastable
(
bool
&
divisible_by_4
,
argument
result
,
argument
barg
,
Arguments
...
args
)
{
{
divisible_by_4
=
false
;
divisible_by_4
=
false
;
auto
bshape
=
barg
.
get_shape
();
auto
bshape
=
barg
.
get_shape
();
const
bool
standard
=
const
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
const
bool
same_shapes
=
all_of
(
const
bool
same_shapes
=
{
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
all_of
(
{
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
// TODO: Check result and args shape is the same
// TODO: Check result and args shape is the same
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
{
{
...
@@ -241,8 +243,7 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen
...
@@ -241,8 +243,7 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
{
divisible_by_4
=
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
return
true
;
return
true
;
}
}
...
@@ -250,7 +251,7 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen
...
@@ -250,7 +251,7 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen
return
false
;
return
false
;
}
}
inline
bool
broadcastable
(
bool
&
divisible_by_4
,
argument
,
argument
)
inline
bool
broadcastable
(
bool
&
divisible_by_4
,
argument
,
argument
)
{
{
divisible_by_4
=
false
;
divisible_by_4
=
false
;
return
false
;
return
false
;
...
@@ -265,9 +266,7 @@ inline auto nary(hipStream_t stream, argument result)
...
@@ -265,9 +266,7 @@ inline auto nary(hipStream_t stream, argument result)
// Unary
// Unary
inline
auto
nary
(
hipStream_t
stream
,
argument
result
,
argument
arg
)
inline
auto
nary
(
hipStream_t
stream
,
argument
result
,
argument
arg
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
nary_impl
(
stream
,
f
,
result
,
arg
);
};
nary_impl
(
stream
,
f
,
result
,
arg
);
};
}
}
// Binary
// Binary
...
@@ -275,7 +274,7 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar
...
@@ -275,7 +274,7 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
bool
divisible_by_4
=
false
;
bool
divisible_by_4
=
false
;
if
(
broadcastable
(
divisible_by_4
,
result
,
barg
,
arg
))
if
(
broadcastable
(
divisible_by_4
,
result
,
barg
,
arg
))
{
{
if
(
divisible_by_4
)
if
(
divisible_by_4
)
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg
,
arg
);
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg
,
arg
);
...
@@ -296,7 +295,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
...
@@ -296,7 +295,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
auto
barg1
=
back_args
(
args
...);
auto
barg1
=
back_args
(
args
...);
bool
fallback
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
bool
fallback
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
bool
divisible_by_4
=
false
;
bool
divisible_by_4
=
false
;
if
(
broadcastable
(
divisible_by_4
,
result
,
barg1
,
args2
...))
if
(
broadcastable
(
divisible_by_4
,
result
,
barg1
,
args2
...))
{
{
if
(
divisible_by_4
)
if
(
divisible_by_4
)
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment