Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8ec57ece
Commit
8ec57ece
authored
Jun 18, 2019
by
Paul
Browse files
Formatting
parent
ee29e116
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
43 deletions
+37
-43
src/include/migraphx/array.hpp
src/include/migraphx/array.hpp
+14
-16
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+9
-13
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+3
-1
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+11
-13
No files found.
src/include/migraphx/array.hpp
View file @
8ec57ece
...
...
@@ -13,34 +13,33 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
detail
{
template
<
class
R
,
class
...
>
struct
array_type
{
using
type
=
R
;
};
struct
array_type
{
using
type
=
R
;
};
template
<
class
...
Ts
>
struct
array_type
<
void
,
Ts
...
>
:
std
::
common_type
<
Ts
...
>
{};
struct
array_type
<
void
,
Ts
...
>
:
std
::
common_type
<
Ts
...
>
{
};
template
<
class
R
,
class
...
Ts
>
using
array_type_t
=
typename
array_type
<
R
,
Ts
...
>::
type
;
template
<
class
T
,
std
::
size_t
N
,
std
::
size_t
...
I
>
constexpr
std
::
array
<
std
::
remove_cv_t
<
T
>
,
N
>
to_array_impl
(
T
(
&
a
)[
N
],
seq
<
I
...
>
)
constexpr
std
::
array
<
std
::
remove_cv_t
<
T
>
,
N
>
to_array_impl
(
T
(
&
a
)[
N
],
seq
<
I
...
>
)
{
return
{
{
a
[
I
]...}
};
return
{{
a
[
I
]...}};
}
}
// namespace detail
template
<
class
Result
=
void
,
class
...
Ts
,
MIGRAPHX_REQUIRES
((
sizeof
...(
Ts
)
>
0
))
>
constexpr
std
::
array
<
detail
::
array_type_t
<
Result
,
Ts
...
>
,
sizeof
...(
Ts
)
>
make_array
(
Ts
&&
...
xs
)
constexpr
std
::
array
<
detail
::
array_type_t
<
Result
,
Ts
...
>
,
sizeof
...(
Ts
)
>
make_array
(
Ts
&&
...
xs
)
{
return
{
static_cast
<
detail
::
array_type_t
<
Result
,
Ts
...
>>
(
std
::
forward
<
Ts
>
(
xs
))...
};
return
{
static_cast
<
detail
::
array_type_t
<
Result
,
Ts
...
>>
(
std
::
forward
<
Ts
>
(
xs
))...};
}
constexpr
std
::
array
<
int
,
0
>
make_array
()
{
return
{};
}
constexpr
std
::
array
<
int
,
0
>
make_array
()
{
return
{};
}
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
to_array
(
T
(
&
a
)[
N
])
...
...
@@ -51,10 +50,9 @@ constexpr auto to_array(T (&a)[N])
namespace
detail
{
template
<
std
::
size_t
Offset
=
0
,
class
Array
,
std
::
size_t
...
I
>
constexpr
auto
rearray_impl
(
Array
a
,
seq
<
I
...
>
)
constexpr
auto
rearray_impl
(
Array
a
,
seq
<
I
...
>
)
{
return
make_array
(
a
[
I
+
Offset
]...);
return
make_array
(
a
[
I
+
Offset
]...);
}
}
// namespace detail
...
...
src/include/migraphx/functional.hpp
View file @
8ec57ece
...
...
@@ -15,7 +15,7 @@ struct swallow
}
};
template
<
class
T
>
template
<
class
T
>
auto
tuple_size
(
const
T
&
)
{
return
typename
std
::
tuple_size
<
T
>::
type
{};
...
...
@@ -161,39 +161,35 @@ auto index_of(T& x)
return
[
&
](
auto
&&
y
)
{
return
x
[
y
];
};
}
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
decltype
(
auto
)
front_args
(
T
&&
x
,
Ts
&&
...)
{
return
static_cast
<
T
&&>
(
x
);
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
decltype
(
auto
)
back_args
(
Ts
&&
...
xs
)
{
return
std
::
get
<
sizeof
...(
Ts
)
-
1
>
(
std
::
tuple
<
Ts
&&
...
>
(
static_cast
<
Ts
&&>
(
xs
)...));
}
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
auto
pop_front_args
(
T
&&
,
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
f
(
static_cast
<
Ts
&&>
(
xs
)...);
};
return
[
&
](
auto
f
)
{
f
(
static_cast
<
Ts
&&>
(
xs
)...);
};
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
pop_back_args
(
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
using
tuple_type
=
std
::
tuple
<
Ts
&&
...
>
;
auto
t
=
tuple_type
(
static_cast
<
Ts
&&>
(
xs
)...);
sequence_c
<
sizeof
...(
Ts
)
-
1
>
([
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
static_cast
<
tuple_type
&&>
(
t
))...);
});
auto
t
=
tuple_type
(
static_cast
<
Ts
&&>
(
xs
)...);
sequence_c
<
sizeof
...(
Ts
)
-
1
>
(
[
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
static_cast
<
tuple_type
&&>
(
t
))...);
});
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/ranges.hpp
View file @
8ec57ece
...
...
@@ -33,7 +33,9 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return
std
::
find
(
c
.
begin
(),
c
.
end
(),
x
);
}
struct
empty
{};
struct
empty
{
};
}
// namespace detail
...
...
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
8ec57ece
...
...
@@ -259,8 +259,7 @@ void binary_broadcast_impl(
}
template
<
class
F
,
class
...
Arguments
>
void
nary_broadcast_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg
,
Arguments
...
args
)
void
nary_broadcast_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg
,
Arguments
...
args
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
barg
.
get_shape
();
...
...
@@ -275,7 +274,7 @@ void nary_broadcast_impl(
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg
,
args
...)([
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
...
...
@@ -289,9 +288,9 @@ void nary_broadcast_impl(
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
output
.
data
()[
i
]
=
f
(
inputs
.
data
()[
i
]...,
b
);
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
output
.
data
()[
i
]
=
f
(
inputs
.
data
()[
i
]...,
b
);
}
});
});
...
...
@@ -363,20 +362,19 @@ auto nary(hipStream_t stream, argument result)
}
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
auto
barg
=
back_args
(
args
...);
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
auto
bshape
=
barg
.
get_shape
();
const
bool
standard
=
all_of
({
args2
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
const
bool
same_shapes
=
all_of
({
args2
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
const
bool
standard
=
all_of
({
args2
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
const
bool
same_shapes
=
all_of
(
{
args2
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
// TODO: Check result and args shape is the same
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
const
auto
&
strides
=
bshape
.
strides
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment