Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
53793762
Commit
53793762
authored
Jun 18, 2019
by
Paul
Browse files
Formatting
parent
55422f0e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
32 deletions
+34
-32
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+34
-31
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
+0
-1
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
53793762
...
...
@@ -241,7 +241,8 @@ void binary_broadcast_impl(
}
template
<
class
F
,
class
...
Arguments
>
void
nary_broadcast_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg
,
Arguments
...
args
)
void
nary_broadcast_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg
,
Arguments
...
args
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
barg
.
get_shape
();
...
...
@@ -258,35 +259,36 @@ void nary_broadcast_vec_impl(hipStream_t stream, F f, argument result, argument
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
hip_vec_visit_all
<
vec_size
>
(
result
,
barg
,
args
...)([
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
const
std
::
size_t
nelements
=
output
.
size
()
/
vec_size
;
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2048
/
vec_size
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
binput
.
data
()[
i
];
}
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
auto
bidx
=
((
i
*
vec_size
)
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
bp
[
bidx
];
auto
out
=
output
.
data
()[
i
];
pack
(
inputs
.
data
()[
i
]...)([
&
](
auto
...
xs
)
__device__
{
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
output
.
data
()[
i
][
j
]
=
f
(
xs
[
j
]...,
b
);
}
});
output
.
data
()[
i
]
=
out
;
}
hip_vec_visit_all
<
vec_size
>
(
result
,
barg
,
args
...)(
[
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
const
std
::
size_t
nelements
=
output
.
size
()
/
vec_size
;
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2048
/
vec_size
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
binput
.
data
()[
i
];
}
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
auto
bidx
=
((
i
*
vec_size
)
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
bp
[
bidx
];
auto
out
=
output
.
data
()[
i
];
pack
(
inputs
.
data
()[
i
]...)([
&
](
auto
...
xs
)
__device__
{
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
output
.
data
()[
i
][
j
]
=
f
(
xs
[
j
]...,
b
);
}
});
output
.
data
()[
i
]
=
out
;
}
});
});
});
}
template
<
class
F
,
class
...
Arguments
>
...
...
@@ -417,8 +419,9 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
else
...
...
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
View file @
53793762
...
...
@@ -58,7 +58,6 @@ struct device_type<half>
using
type
=
gpu_half
;
};
template
<
class
T
>
struct
host_type
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment