Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ee29e116
Commit
ee29e116
authored
Jun 18, 2019
by
Paul
Browse files
Do generic nary broadcast
parent
c1d244d9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
220 additions
and
72 deletions
+220
-72
src/include/migraphx/array.hpp
src/include/migraphx/array.hpp
+77
-0
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+47
-2
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+20
-0
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+75
-63
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+1
-7
No files found.
src/include/migraphx/array.hpp
0 → 100644
View file @
ee29e116
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
template
<
class
R
,
class
...
>
struct
array_type
{
using
type
=
R
;
};
template
<
class
...
Ts
>
struct
array_type
<
void
,
Ts
...
>
:
std
::
common_type
<
Ts
...
>
{};
template
<
class
R
,
class
...
Ts
>
using
array_type_t
=
typename
array_type
<
R
,
Ts
...
>::
type
;
template
<
class
T
,
std
::
size_t
N
,
std
::
size_t
...
I
>
constexpr
std
::
array
<
std
::
remove_cv_t
<
T
>
,
N
>
to_array_impl
(
T
(
&
a
)[
N
],
seq
<
I
...
>
)
{
return
{
{
a
[
I
]...}
};
}
}
// namespace detail
template
<
class
Result
=
void
,
class
...
Ts
,
MIGRAPHX_REQUIRES
((
sizeof
...(
Ts
)
>
0
))
>
constexpr
std
::
array
<
detail
::
array_type_t
<
Result
,
Ts
...
>
,
sizeof
...(
Ts
)
>
make_array
(
Ts
&&
...
xs
)
{
return
{
static_cast
<
detail
::
array_type_t
<
Result
,
Ts
...
>>
(
std
::
forward
<
Ts
>
(
xs
))...
};
}
constexpr
std
::
array
<
int
,
0
>
make_array
()
{
return
{};
}
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
to_array
(
T
(
&
a
)[
N
])
{
return
detail
::
to_array_impl
(
a
,
detail
::
gens
<
N
>
{});
}
namespace
detail
{
template
<
std
::
size_t
Offset
=
0
,
class
Array
,
std
::
size_t
...
I
>
constexpr
auto
rearray_impl
(
Array
a
,
seq
<
I
...
>
)
{
return
make_array
(
a
[
I
+
Offset
]...);
}
}
// namespace detail
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
pop_front
(
std
::
array
<
T
,
N
>
a
)
{
return
detail
::
rearray_impl
(
a
,
detail
::
gens
<
N
-
1
>
{});
}
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
pop_back
(
std
::
array
<
T
,
N
>
a
)
{
return
detail
::
rearray_impl
<
1
>
(
a
,
detail
::
gens
<
N
-
1
>
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/functional.hpp
View file @
ee29e116
...
@@ -15,6 +15,12 @@ struct swallow
...
@@ -15,6 +15,12 @@ struct swallow
}
}
};
};
template
<
class
T
>
auto
tuple_size
(
const
T
&
)
{
return
typename
std
::
tuple_size
<
T
>::
type
{};
}
namespace
detail
{
namespace
detail
{
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
...
@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
...
@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
return
detail
::
sequence_c_impl
(
f
,
detail
::
gens
<
N
>
{});
return
detail
::
sequence_c_impl
(
f
,
detail
::
gens
<
N
>
{});
}
}
template
<
class
IntegerConstant
,
class
F
>
constexpr
auto
sequence
(
IntegerConstant
ic
,
F
&&
f
)
{
return
sequence_c
<
ic
>
(
f
);
}
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
{
{
...
@@ -95,9 +107,9 @@ constexpr void each_args(F)
...
@@ -95,9 +107,9 @@ constexpr void each_args(F)
}
}
template
<
class
F
,
class
T
>
template
<
class
F
,
class
T
>
auto
unpack
(
F
f
,
T
&
x
)
auto
unpack
(
F
f
,
T
&
&
x
)
{
{
return
sequence
_c
<
std
::
tuple_size
<
T
>
{}
>
(
[
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
x
)...);
});
return
sequence
(
tuple_size
(
x
),
[
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
static_cast
<
T
&&>
(
x
)
)...);
});
}
}
/// Implements a fix-point combinator
/// Implements a fix-point combinator
...
@@ -149,6 +161,39 @@ auto index_of(T& x)
...
@@ -149,6 +161,39 @@ auto index_of(T& x)
return
[
&
](
auto
&&
y
)
{
return
x
[
y
];
};
return
[
&
](
auto
&&
y
)
{
return
x
[
y
];
};
}
}
template
<
class
T
,
class
...
Ts
>
decltype
(
auto
)
front_args
(
T
&&
x
,
Ts
&&
...)
{
return
static_cast
<
T
&&>
(
x
);
}
template
<
class
...
Ts
>
decltype
(
auto
)
back_args
(
Ts
&&
...
xs
)
{
return
std
::
get
<
sizeof
...(
Ts
)
-
1
>
(
std
::
tuple
<
Ts
&&
...
>
(
static_cast
<
Ts
&&>
(
xs
)...));
}
template
<
class
T
,
class
...
Ts
>
auto
pop_front_args
(
T
&&
,
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
f
(
static_cast
<
Ts
&&>
(
xs
)...);
};
}
template
<
class
...
Ts
>
auto
pop_back_args
(
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
using
tuple_type
=
std
::
tuple
<
Ts
&&
...
>
;
auto
t
=
tuple_type
(
static_cast
<
Ts
&&>
(
xs
)...);
sequence_c
<
sizeof
...(
Ts
)
-
1
>
([
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
static_cast
<
tuple_type
&&>
(
t
))...);
});
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/ranges.hpp
View file @
ee29e116
...
@@ -33,6 +33,8 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
...
@@ -33,6 +33,8 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return
std
::
find
(
c
.
begin
(),
c
.
end
(),
x
);
return
std
::
find
(
c
.
begin
(),
c
.
end
(),
x
);
}
}
struct
empty
{};
}
// namespace detail
}
// namespace detail
template
<
class
C
,
class
T
>
template
<
class
C
,
class
T
>
...
@@ -71,6 +73,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
...
@@ -71,6 +73,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
}
template
<
class
Predicate
>
bool
all_of
(
detail
::
empty
,
const
Predicate
&
)
{
return
true
;
}
template
<
class
C
,
class
Predicate
>
template
<
class
C
,
class
Predicate
>
bool
any_of
(
const
C
&
c
,
const
Predicate
&
p
)
bool
any_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
{
...
@@ -83,6 +91,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
...
@@ -83,6 +91,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
}
template
<
class
Predicate
>
bool
any_of
(
detail
::
empty
,
const
Predicate
&
)
{
return
false
;
}
template
<
class
C
,
class
Predicate
>
template
<
class
C
,
class
Predicate
>
bool
none_of
(
const
C
&
c
,
const
Predicate
&
p
)
bool
none_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
{
...
@@ -95,6 +109,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
...
@@ -95,6 +109,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
}
template
<
class
Predicate
>
bool
none_of
(
detail
::
empty
,
const
Predicate
&
)
{
return
true
;
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
void
copy
(
Range
&&
r
,
Iterator
it
)
{
{
...
...
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
ee29e116
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/array.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -257,6 +258,45 @@ void binary_broadcast_impl(
...
@@ -257,6 +258,45 @@ void binary_broadcast_impl(
});
});
}
}
template
<
class
F
,
class
...
Arguments
>
void
nary_broadcast_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg
,
Arguments
...
args
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
barg
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg
,
args
...)([
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
binput
.
data
()[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
output
.
data
()[
i
]
=
f
(
inputs
.
data
()[
i
]...,
b
);
}
});
});
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
...
@@ -304,12 +344,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
...
@@ -304,12 +344,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
nary_nonstandard_impl
(
stream
,
f
,
result
,
args
...);
nary_nonstandard_impl
(
stream
,
f
,
result
,
args
...);
}
}
template
<
class
F
>
void
nary_impl
(
hipStream_t
stream
,
F
f
,
argument
result
)
{
nary_standard_impl
(
stream
,
f
,
result
);
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary_nonstandard
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
auto
nary_nonstandard
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
...
@@ -323,71 +357,49 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
...
@@ -323,71 +357,49 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
auto
nary
(
hipStream_t
stream
,
argument
result
)
{
{
return
[
=
](
auto
f
)
{
nary_impl
(
stream
,
f
,
result
,
args
...
);
};
return
[
=
](
auto
f
)
{
nary_
standard_
impl
(
stream
,
f
,
result
);
};
}
}
inline
auto
template
<
class
...
Arguments
>
nary
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
// TODO: Check result and arg1 shape is the same
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
broadcasted
()
and
not
arg2
.
get_shape
().
scalar
())
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
const
auto
&
strides
=
arg2
.
get_shape
().
strides
();
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
arg2
.
get_shape
().
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
binary_broadcast_vec_impl
(
stream
,
f
,
result
,
arg1
,
arg2
);
else
binary_broadcast_impl
(
stream
,
f
,
result
,
arg1
,
arg2
);
return
;
}
}
nary_impl
(
stream
,
f
,
result
,
arg1
,
arg2
);
};
}
inline
auto
nary
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// TODO: Check result and arg1 shape is the same
auto
barg
=
back_args
(
args
...);
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
standard
()
and
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
arg3
.
get_shape
().
broadcasted
())
auto
bshape
=
barg
.
get_shape
();
{
const
bool
standard
=
all_of
({
args2
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
const
bool
same_shapes
=
const
auto
&
strides
=
arg3
.
get_shape
().
strides
();
all_of
({
args2
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
// TODO: Check result and args shape is the same
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
not
bshape
.
scalar
())
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
arg3
.
get_shape
().
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
const
auto
&
strides
=
bshape
.
strides
();
if
(
divisible_by_4
)
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
trinary_broadcast_vec_impl
(
stream
,
f
,
result
,
arg1
,
arg2
,
arg3
);
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
else
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
trinary_broadcast_impl
(
stream
,
f
,
result
,
arg1
,
arg2
,
arg3
);
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
return
;
assert
(
bshape
.
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
nary_broadcast_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
// const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
// (arg1.get_shape().elements() % 4 == 0);
// if(divisible_by_4)
// binary_broadcast_vec_impl(stream, f, result, arg1, arg);
// else
// binary_broadcast_impl(stream, f, result, arg1, arg);
// return;
}
}
}
}
}
);
nary_impl
(
stream
,
f
,
result
,
arg
1
,
arg2
,
arg3
);
nary_impl
(
stream
,
f
,
result
,
arg
s
...
);
};
};
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
ee29e116
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/array.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -122,13 +123,6 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
...
@@ -122,13 +123,6 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
s
.
strides
()[
1
]
!=
0
and
s
.
strides
()[
2
]
==
0
and
s
.
strides
()[
3
]
==
0
;
s
.
strides
()[
1
]
!=
0
and
s
.
strides
()[
2
]
==
0
and
s
.
strides
()[
3
]
==
0
;
}
}
// TODO: Move to another header
template
<
class
T
,
class
...
Ts
>
std
::
array
<
T
,
sizeof
...(
Ts
)
+
1
>
make_array
(
T
x
,
Ts
...
xs
)
{
return
{
std
::
move
(
x
),
std
::
move
(
static_cast
<
T
>
(
xs
))...};
}
MIGRAPHX_PRED_MATCHER
(
fusable_conv
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
fusable_conv
,
instruction_ref
ins
)
{
{
if
(
ins
->
name
()
!=
"gpu::convolution"
)
if
(
ins
->
name
()
!=
"gpu::convolution"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment