Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a919e88a
".github/vscode:/vscode.git/clone" did not exist on "bad334fa5b9dd9d8efb98c2cd1b1bfe533434322"
Commit
a919e88a
authored
Aug 24, 2018
by
Paul
Browse files
Add nary to handle any number of arguments
parent
032d0650
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
99 additions
and
11 deletions
+99
-11
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+8
-0
src/include/migraph/ranges.hpp
src/include/migraph/ranges.hpp
+12
-0
src/targets/gpu/device/add_relu.cpp
src/targets/gpu/device/add_relu.cpp
+4
-4
src/targets/gpu/device/contiguous.cpp
src/targets/gpu/device/contiguous.cpp
+3
-3
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+68
-0
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+1
-1
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
+1
-1
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
+1
-1
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+1
-1
No files found.
src/include/migraph/functional.hpp
View file @
a919e88a
...
@@ -34,6 +34,14 @@ auto fix(F f)
...
@@ -34,6 +34,14 @@ auto fix(F f)
return
fix
<
void
>
(
f
);
return
fix
<
void
>
(
f
);
}
}
template
<
class
...
Ts
>
auto
make_sequence
(
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
}
// namespace migraph
}
// namespace migraph
#endif
#endif
src/include/migraph/ranges.hpp
View file @
a919e88a
...
@@ -60,6 +60,18 @@ bool contains(const std::initializer_list<T>& c, const U& x)
...
@@ -60,6 +60,18 @@ bool contains(const std::initializer_list<T>& c, const U& x)
return
generic_find
(
c
,
x
)
!=
c
.
end
();
return
generic_find
(
c
,
x
)
!=
c
.
end
();
}
}
template
<
class
C
,
class
Predicate
>
bool
all_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
all_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
void
copy
(
Range
&&
r
,
Iterator
it
)
{
{
...
...
src/targets/gpu/device/add_relu.cpp
View file @
a919e88a
#include <migraph/gpu/device/
contiguous
.hpp>
#include <migraph/gpu/device/
add_relu
.hpp>
#include <migraph/gpu/device/
bi
nary.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add_relu
(
argument
arg1
,
argument
arg2
,
argument
result
)
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
)
{
{
bi
nary_standard
(
arg1
,
arg2
,
result
,
[](
auto
x
,
auto
y
)
{
return
max
(
0
,
x
+
y
);
});
nary_standard
(
result
,
arg1
,
arg2
)(
[](
auto
x
,
auto
y
)
{
return
max
(
0
,
x
+
y
);
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/contiguous.cpp
View file @
a919e88a
#include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/
u
nary.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
contiguous
(
argument
arg
,
argument
result
)
void
contiguous
(
argument
result
,
argument
arg
)
{
{
u
nary_nonstandard
(
arg
,
result
,
[](
auto
x
)
{
return
x
;
});
nary_nonstandard
(
result
,
arg
)(
[](
auto
x
)
{
return
x
;
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
0 → 100644
View file @
a919e88a
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
namespace
migraph
{
namespace
gpu
{
namespace
device
{
template
<
class
...
Arguments
>
auto
nary
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
if
(
all_of
({
args
...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
}))
nary_standard
(
result
,
args
...)(
f
);
else
nary_nonstandard
(
result
,
args
...)(
f
);
};
}
template
<
class
...
Arguments
>
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
auto
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
auto
data
=
make_sequence
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
().
lens
(),
inputs
.
get_shape
().
strides
()},
inputs
.
data
())...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
.
lens
(),
output_shape
.
strides
());
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
data
([
&
](
auto
...
ps
)
{
auto
outidx
=
out_desc
.
multi
(
i
);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
});
});
});
});
};
}
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
auto
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
auto
data
=
make_sequence
(
inputs
.
data
()...);
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
});
};
}
}
// namespace device
}
// namespace gpu
}
// namespace migraph
#endif
src/targets/gpu/fuse_ops.cpp
View file @
a919e88a
...
@@ -17,7 +17,7 @@ struct hip_add_relu
...
@@ -17,7 +17,7 @@ struct hip_add_relu
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
device
::
add_relu
(
args
.
at
(
0
),
args
.
at
(
1
)
,
args
.
at
(
2
)
);
device
::
add_relu
(
args
.
at
(
2
),
args
.
at
(
0
),
args
.
at
(
1
));
return
args
.
at
(
2
);
return
args
.
at
(
2
);
}
}
};
};
...
...
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
View file @
a919e88a
...
@@ -8,7 +8,7 @@ namespace migraph {
...
@@ -8,7 +8,7 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add_relu
(
argument
arg1
,
argument
arg2
,
argument
result
);
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
View file @
a919e88a
...
@@ -7,7 +7,7 @@ namespace migraph {
...
@@ -7,7 +7,7 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
contiguous
(
argument
arg
,
argument
result
);
void
contiguous
(
argument
result
,
argument
arg
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/lowering.cpp
View file @
a919e88a
...
@@ -253,7 +253,7 @@ struct miopen_contiguous
...
@@ -253,7 +253,7 @@ struct miopen_contiguous
assert
(
output_shape
==
args
[
1
].
get_shape
());
assert
(
output_shape
==
args
[
1
].
get_shape
());
assert
(
output_shape
.
standard
());
assert
(
output_shape
.
standard
());
(
void
)
output_shape
;
(
void
)
output_shape
;
device
::
contiguous
(
args
.
at
(
0
),
args
.
at
(
1
));
device
::
contiguous
(
args
.
at
(
1
),
args
.
at
(
0
));
return
args
.
at
(
1
);
return
args
.
at
(
1
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment