Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a919e88a
Commit
a919e88a
authored
Aug 24, 2018
by
Paul
Browse files
Add nary to handle any number of arguments
parent
032d0650
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
99 additions
and
11 deletions
+99
-11
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+8
-0
src/include/migraph/ranges.hpp
src/include/migraph/ranges.hpp
+12
-0
src/targets/gpu/device/add_relu.cpp
src/targets/gpu/device/add_relu.cpp
+4
-4
src/targets/gpu/device/contiguous.cpp
src/targets/gpu/device/contiguous.cpp
+3
-3
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+68
-0
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+1
-1
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
+1
-1
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
+1
-1
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+1
-1
No files found.
src/include/migraph/functional.hpp
View file @
a919e88a
...
...
@@ -34,6 +34,14 @@ auto fix(F f)
return
fix
<
void
>
(
f
);
}
/// Bundles a pack of values into a closure.
///
/// The returned callable holds a copy of every value in `xs`. Invoking it
/// with a function `f` calls `f(xs...)` and returns whatever `f` returns.
template <class... Ts>
auto make_sequence(Ts... xs)
{
    // Capture the pack explicitly by copy so the closure owns its values.
    auto unpack_into = [xs...](auto f) { return f(xs...); };
    return unpack_into;
}
}
// namespace migraph
#endif
src/include/migraph/ranges.hpp
View file @
a919e88a
...
...
@@ -60,6 +60,18 @@ bool contains(const std::initializer_list<T>& c, const U& x)
return
generic_find
(
c
,
x
)
!=
c
.
end
();
}
/// Container convenience wrapper: returns true when predicate `p` holds for
/// every element in `[c.begin(), c.end())`, which includes the empty range.
template <class C, class Predicate>
bool all_of(const C& c, const Predicate& p)
{
    // "All satisfy p" is the same as "no element fails p".
    const auto last = c.end();
    return std::find_if_not(c.begin(), last, p) == last;
}
/// Overload for braced lists: returns true when predicate `p` holds for every
/// element of the initializer list `c`, which includes the empty list.
template <class T, class Predicate>
bool all_of(const std::initializer_list<T>& c, const Predicate& p)
{
    for(auto it = c.begin(); it != c.end(); ++it)
    {
        // Stop at the first element that fails the predicate.
        if(!p(*it))
            return false;
    }
    return true;
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
...
...
src/targets/gpu/device/add_relu.cpp
View file @
a919e88a
#include <migraph/gpu/device/
contiguous
.hpp>
#include <migraph/gpu/device/
bi
nary.hpp>
#include <migraph/gpu/device/
add_relu
.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace
migraph
{
namespace
gpu
{
namespace
device
{
void
add_relu
(
argument
arg1
,
argument
arg2
,
argument
result
)
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
)
{
bi
nary_standard
(
arg1
,
arg2
,
result
,
[](
auto
x
,
auto
y
)
{
return
max
(
0
,
x
+
y
);
});
nary_standard
(
result
,
arg1
,
arg2
)(
[](
auto
x
,
auto
y
)
{
return
max
(
0
,
x
+
y
);
});
}
}
// namespace device
...
...
src/targets/gpu/device/contiguous.cpp
View file @
a919e88a
#include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/
u
nary.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace
migraph
{
namespace
gpu
{
namespace
device
{
void
contiguous
(
argument
arg
,
argument
result
)
void
contiguous
(
argument
result
,
argument
arg
)
{
u
nary_nonstandard
(
arg
,
result
,
[](
auto
x
)
{
return
x
;
});
nary_nonstandard
(
result
,
arg
)(
[](
auto
x
)
{
return
x
;
});
}
}
// namespace device
...
...
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
0 → 100644
View file @
a919e88a
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
namespace
migraph
{
namespace
gpu
{
namespace
device
{
// Forward declarations. nary() dispatches to these two templates, which are
// defined further down in this header. Without a prior declaration the calls
// below would rely solely on argument-dependent lookup at instantiation time.
template <class... Arguments>
auto nary_standard(argument result, Arguments... args);

template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args);

/// Builds an n-ary elementwise operation over `args` that writes to `result`.
///
/// Returns a callable taking the elementwise function `f`. When invoked it
/// forwards to `nary_standard(result, args...)(f)` if the shape of every input
/// argument reports `standard()`, and to `nary_nonstandard(result, args...)(f)`
/// otherwise. The shape of `result` is not part of this check.
template <class... Arguments>
auto nary(argument result, Arguments... args)
{
    return [=](auto f) {
        // The predicate inspects a `shape`, so build the list from each
        // argument's shape rather than from the argument objects themselves.
        const bool inputs_standard =
            all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
        if(inputs_standard)
            nary_standard(result, args...)(f);
        else
            nary_nonstandard(result, args...)(f);
    };
}
/// Builds an n-ary elementwise operation in which every input is addressed
/// through its own lens/strides descriptor.
///
/// Returns a callable taking the elementwise function `f`. When invoked, for
/// each index `i` handed out by `gs_launch(elements)` it converts `i` to a
/// multi-index with the output descriptor, reads each input at the offset its
/// own descriptor computes for that multi-index, and stores
/// `f(inputs...)` at `output[i]`.
template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args)
{
    return [=](auto f) {
        auto out_shape = result.get_shape();
        visit_all(result, args...)([&](auto output, auto... inputs) {
            visit_tensor_size(out_shape.lens().size(), [&](auto ndim) {
                // One (descriptor, data pointer) pair per input, bundled so the
                // whole pack can be copied by value into the lambda below.
                auto input_pairs = make_sequence(
                    std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
                                                               inputs.get_shape().strides()},
                                   inputs.data())...);
                hip_tensor_descriptor<ndim> out_desc(out_shape.lens(), out_shape.strides());
                auto* out_ptr = output.data();
                gs_launch(out_shape.elements())([=](auto i) {
                    input_pairs([&](auto... entry) {
                        // NOTE(review): multi() presumably maps the flat index to
                        // per-dimension coordinates and linear() maps them back
                        // to an offset using that input's strides - confirm.
                        auto coords = out_desc.multi(i);
                        out_ptr[i]  = f(entry.second[entry.first.linear(coords)]...);
                    });
                });
            });
        });
    };
}
/// Builds an n-ary elementwise operation that indexes the output and every
/// input with the same flat index.
///
/// Returns a callable taking the elementwise function `f`. When invoked, for
/// each index `i` handed out by `gs_launch(elements)` it stores
/// `f(input0[i], input1[i], ...)` at `output[i]`.
template <class... Arguments>
auto nary_standard(argument result, Arguments... args)
{
    return [=](auto f) {
        // NOTE: no check is made here that the inputs hold as many elements
        // as the output; only the output's element count drives the launch.
        auto out_shape = result.get_shape();
        visit_all(result, args...)([&](auto output, auto... inputs) {
            // Bundle the raw input pointers so they are copied by value into
            // the lambda passed to gs_launch.
            auto input_ptrs = make_sequence(inputs.data()...);
            auto* out_ptr   = output.data();
            gs_launch(out_shape.elements())([=](auto i) {
                input_ptrs([&](auto... src) { out_ptr[i] = f(src[i]...); });
            });
        });
    };
}
}
// namespace device
}
// namespace gpu
}
// namespace migraph
#endif
src/targets/gpu/fuse_ops.cpp
View file @
a919e88a
...
...
@@ -17,7 +17,7 @@ struct hip_add_relu
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
add_relu
(
args
.
at
(
0
),
args
.
at
(
1
)
,
args
.
at
(
2
)
);
device
::
add_relu
(
args
.
at
(
2
),
args
.
at
(
0
),
args
.
at
(
1
));
return
args
.
at
(
2
);
}
};
...
...
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
View file @
a919e88a
...
...
@@ -8,7 +8,7 @@ namespace migraph {
namespace
gpu
{
namespace
device
{
void
add_relu
(
argument
arg1
,
argument
arg2
,
argument
result
);
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
);
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
View file @
a919e88a
...
...
@@ -7,7 +7,7 @@ namespace migraph {
namespace
gpu
{
namespace
device
{
void
contiguous
(
argument
arg
,
argument
result
);
void
contiguous
(
argument
result
,
argument
arg
);
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/lowering.cpp
View file @
a919e88a
...
...
@@ -253,7 +253,7 @@ struct miopen_contiguous
assert
(
output_shape
==
args
[
1
].
get_shape
());
assert
(
output_shape
.
standard
());
(
void
)
output_shape
;
device
::
contiguous
(
args
.
at
(
0
),
args
.
at
(
1
));
device
::
contiguous
(
args
.
at
(
1
),
args
.
at
(
0
));
return
args
.
at
(
1
);
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment