Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
96e74d6e
Commit
96e74d6e
authored
Oct 31, 2018
by
Paul
Browse files
Formatting
parent
a0c4afbf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
14 deletions
+13
-14
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+3
-3
src/targets/gpu/device/include/migraph/gpu/device/types.hpp
src/targets/gpu/device/include/migraph/gpu/device/types.hpp
+10
-11
No files found.
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
96e74d6e
...
...
@@ -38,8 +38,8 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments..
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
auto
data
=
pack
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
device_cast
(
inputs
.
data
()))...);
auto
data
=
pack
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
device_cast
(
inputs
.
data
()))...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
auto
*
outp
=
device_cast
(
output
.
data
());
gs_launch
(
stream
,
output_shape
.
elements
())([
=
](
auto
i
)
{
...
...
@@ -266,7 +266,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
const
std
::
size_t
vec_size
=
4
;
auto
data
=
pack_vec4
(
device_cast
(
inputs
.
data
())...);
auto
*
outp
=
as_vec4
(
device_cast
(
output
.
data
()));
...
...
src/targets/gpu/device/include/migraph/gpu/device/types.hpp
View file @
96e74d6e
...
...
@@ -17,26 +17,25 @@ namespace device {
using
gpu_half
=
__fp16
;
namespace
detail
{
template
<
class
T
>
template
<
class
T
>
struct
device_type
{
using
type
=
T
;
};
template
<
>
template
<
>
struct
device_type
<
half
>
{
using
type
=
gpu_half
;
};
template
<
class
T
>
template
<
class
T
>
struct
host_type
{
using
type
=
T
;
};
template
<
>
template
<
>
struct
device_type
<
gpu_half
>
{
using
type
=
half
;
...
...
@@ -44,31 +43,31 @@ struct device_type<gpu_half>
}
// namespace detail
template
<
class
T
>
template
<
class
T
>
using
host_type
=
typename
detail
::
host_type
<
T
>::
type
;
template
<
class
T
>
template
<
class
T
>
using
device_type
=
typename
detail
::
device_type
<
T
>::
type
;
template
<
class
T
>
template
<
class
T
>
host_type
<
T
>
host_cast
(
T
x
)
{
return
reinterpret_cast
<
host_type
<
T
>>
(
x
);
}
template
<
class
T
>
template
<
class
T
>
host_type
<
T
>*
host_cast
(
T
*
x
)
{
return
reinterpret_cast
<
host_type
<
T
>*>
(
x
);
}
template
<
class
T
>
template
<
class
T
>
device_type
<
T
>
device_cast
(
T
x
)
{
return
reinterpret_cast
<
device_type
<
T
>>
(
x
);
}
template
<
class
T
>
template
<
class
T
>
device_type
<
T
>*
device_cast
(
T
*
x
)
{
return
reinterpret_cast
<
device_type
<
T
>*>
(
x
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment