Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9979fada
"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "65e6f48ce43fe21b0fb282ee8ece3421e1cd6e52"
Commit
9979fada
authored
Oct 26, 2018
by
Paul
Browse files
Add stream to device functions
parent
3f84bc67
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
71 additions
and
67 deletions
+71
-67
src/targets/gpu/add.cpp
src/targets/gpu/add.cpp
+2
-2
src/targets/gpu/concat.cpp
src/targets/gpu/concat.cpp
+2
-2
src/targets/gpu/contiguous.cpp
src/targets/gpu/contiguous.cpp
+2
-2
src/targets/gpu/device/add.cpp
src/targets/gpu/device/add.cpp
+4
-4
src/targets/gpu/device/add_relu.cpp
src/targets/gpu/device/add_relu.cpp
+4
-4
src/targets/gpu/device/concat.cpp
src/targets/gpu/device/concat.cpp
+2
-4
src/targets/gpu/device/contiguous.cpp
src/targets/gpu/device/contiguous.cpp
+2
-2
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
+4
-4
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+31
-31
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+6
-6
src/targets/gpu/include/migraph/gpu/device/add.hpp
src/targets/gpu/include/migraph/gpu/device/add.hpp
+3
-2
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
+3
-2
src/targets/gpu/include/migraph/gpu/device/concat.hpp
src/targets/gpu/include/migraph/gpu/device/concat.hpp
+4
-1
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
+2
-1
No files found.
src/targets/gpu/add.cpp
View file @
9979fada
...
@@ -14,9 +14,9 @@ shape hip_add::compute_shape(const std::vector<shape>& inputs) const
...
@@ -14,9 +14,9 @@ shape hip_add::compute_shape(const std::vector<shape>& inputs) const
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
hip_add
::
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
hip_add
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
device
::
add
(
args
[
2
],
args
[
0
],
args
[
1
]);
device
::
add
(
ctx
.
get_stream
().
get
(),
args
[
2
],
args
[
0
],
args
[
1
]);
return
args
[
2
];
return
args
[
2
];
}
}
...
...
src/targets/gpu/concat.cpp
View file @
9979fada
...
@@ -15,10 +15,10 @@ shape hip_concat::compute_shape(std::vector<shape> inputs) const
...
@@ -15,10 +15,10 @@ shape hip_concat::compute_shape(std::vector<shape> inputs) const
}
}
argument
argument
hip_concat
::
compute
(
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
hip_concat
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
std
::
vector
<
std
::
size_t
>
offsets
=
op
.
compute_offsets
(
output_shape
,
args
);
std
::
vector
<
std
::
size_t
>
offsets
=
op
.
compute_offsets
(
output_shape
,
args
);
return
device
::
concat
(
output_shape
,
args
,
offsets
);
return
device
::
concat
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
offsets
);
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/contiguous.cpp
View file @
9979fada
...
@@ -13,12 +13,12 @@ shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const
...
@@ -13,12 +13,12 @@ shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
}
argument
argument
miopen_contiguous
::
compute
(
context
&
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
miopen_contiguous
::
compute
(
context
&
ctx
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
assert
(
output_shape
==
args
[
1
].
get_shape
());
assert
(
output_shape
==
args
[
1
].
get_shape
());
assert
(
output_shape
.
standard
());
assert
(
output_shape
.
standard
());
(
void
)
output_shape
;
(
void
)
output_shape
;
device
::
contiguous
(
args
.
at
(
1
),
args
.
at
(
0
));
device
::
contiguous
(
ctx
.
get_stream
().
get
(),
args
.
at
(
1
),
args
.
at
(
0
));
return
args
.
at
(
1
);
return
args
.
at
(
1
);
}
}
...
...
src/targets/gpu/device/add.cpp
View file @
9979fada
...
@@ -5,14 +5,14 @@ namespace migraph {
...
@@ -5,14 +5,14 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
void
add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
nary
(
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
{
return
x
+
y
;
});
nary
(
stream
,
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
{
return
x
+
y
;
});
}
}
void
add
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
void
add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
{
nary
(
result
,
arg1
,
arg2
,
arg3
)([](
auto
x
,
auto
y
,
auto
z
)
{
return
x
+
y
+
z
;
});
nary
(
stream
,
result
,
arg1
,
arg2
,
arg3
)([](
auto
x
,
auto
y
,
auto
z
)
{
return
x
+
y
+
z
;
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/add_relu.cpp
View file @
9979fada
...
@@ -5,17 +5,17 @@ namespace migraph {
...
@@ -5,17 +5,17 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add_relu
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
void
add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
nary
(
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
{
return
std
::
max
<
decltype
(
x
+
y
)
>
(
0
,
x
+
y
);
});
nary
(
stream
,
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
{
return
std
::
max
<
decltype
(
x
+
y
)
>
(
0
,
x
+
y
);
});
}
}
void
add_relu
(
const
argument
&
result
,
void
add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg2
,
const
argument
&
arg3
)
const
argument
&
arg3
)
{
{
nary
(
result
,
arg1
,
arg2
,
arg3
)(
nary
(
stream
,
result
,
arg1
,
arg2
,
arg3
)(
[](
auto
x
,
auto
y
,
auto
z
)
{
return
std
::
max
<
decltype
(
x
+
y
+
z
)
>
(
0
,
x
+
y
+
z
);
});
[](
auto
x
,
auto
y
,
auto
z
)
{
return
std
::
max
<
decltype
(
x
+
y
+
z
)
>
(
0
,
x
+
y
+
z
);
});
}
}
...
...
src/targets/gpu/device/concat.cpp
View file @
9979fada
...
@@ -8,11 +8,10 @@ namespace migraph {
...
@@ -8,11 +8,10 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
argument
concat
(
const
migraph
::
shape
&
output_shape
,
argument
concat
(
hipStream_t
stream
,
const
migraph
::
shape
&
output_shape
,
std
::
vector
<
migraph
::
argument
>
args
,
std
::
vector
<
migraph
::
argument
>
args
,
std
::
vector
<
std
::
size_t
>
offsets
)
std
::
vector
<
std
::
size_t
>
offsets
)
{
{
// migraph::argument& result = args.back();
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
()
-
1
;
l
++
)
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
()
-
1
;
l
++
)
{
{
auto
argl
=
args
[
l
];
auto
argl
=
args
[
l
];
...
@@ -23,12 +22,11 @@ argument concat(const migraph::shape& output_shape,
...
@@ -23,12 +22,11 @@ argument concat(const migraph::shape& output_shape,
const
auto
*
inptr
=
input
.
data
();
const
auto
*
inptr
=
input
.
data
();
hip_tensor_descriptor
<
ndim
>
desc_input
(
input
.
get_shape
());
hip_tensor_descriptor
<
ndim
>
desc_input
(
input
.
get_shape
());
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
gs_launch
(
nelements
)(
gs_launch
(
stream
,
nelements
)(
[
=
](
auto
i
)
{
outptr
[
desc_output
.
linear
(
desc_input
.
multi
(
i
))]
=
inptr
[
i
];
});
[
=
](
auto
i
)
{
outptr
[
desc_output
.
linear
(
desc_input
.
multi
(
i
))]
=
inptr
[
i
];
});
});
});
});
});
}
}
// return result;
return
args
.
back
();
return
args
.
back
();
}
}
...
...
src/targets/gpu/device/contiguous.cpp
View file @
9979fada
...
@@ -6,9 +6,9 @@ namespace migraph {
...
@@ -6,9 +6,9 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
contiguous
(
argument
result
,
argument
arg
)
void
contiguous
(
hipStream_t
stream
,
argument
result
,
argument
arg
)
{
{
nary_nonstandard
(
std
::
move
(
result
),
std
::
move
(
arg
))([](
auto
x
)
{
return
x
;
});
nary_nonstandard
(
stream
,
std
::
move
(
result
),
std
::
move
(
arg
))([](
auto
x
)
{
return
x
;
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
View file @
9979fada
...
@@ -21,7 +21,7 @@ __global__ void launcher(F f)
...
@@ -21,7 +21,7 @@ __global__ void launcher(F f)
f
(
idx
);
f
(
idx
);
}
}
inline
auto
launch
(
std
::
size_t
global
,
std
::
size_t
local
)
inline
auto
launch
(
hipStream_t
stream
,
std
::
size_t
global
,
std
::
size_t
local
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
assert
(
local
>
0
);
assert
(
local
>
0
);
...
@@ -29,17 +29,17 @@ inline auto launch(std::size_t global, std::size_t local)
...
@@ -29,17 +29,17 @@ inline auto launch(std::size_t global, std::size_t local)
using
f_type
=
decltype
(
f
);
using
f_type
=
decltype
(
f
);
dim3
nblocks
(
global
/
local
);
dim3
nblocks
(
global
/
local
);
dim3
nthreads
(
local
);
dim3
nthreads
(
local
);
hipLaunchKernelGGL
((
launcher
<
f_type
>
),
nblocks
,
nthreads
,
0
,
nullptr
,
f
);
hipLaunchKernelGGL
((
launcher
<
f_type
>
),
nblocks
,
nthreads
,
0
,
stream
,
f
);
};
};
}
}
inline
auto
gs_launch
(
std
::
size_t
n
,
std
::
size_t
local
=
1024
)
inline
auto
gs_launch
(
hipStream_t
stream
,
std
::
size_t
n
,
std
::
size_t
local
=
1024
)
{
{
std
::
size_t
groups
=
1
+
n
/
local
;
std
::
size_t
groups
=
1
+
n
/
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
256
,
groups
)
*
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
256
,
groups
)
*
local
;
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
launch
(
nglobal
,
local
)([
=
](
auto
idx
)
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
{
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
{
f
(
i
);
f
(
i
);
...
...
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
9979fada
...
@@ -32,7 +32,7 @@ auto pack_vec4(Ts... xs)
...
@@ -32,7 +32,7 @@ auto pack_vec4(Ts... xs)
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
auto
nary_nonstandard_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
auto
nary_nonstandard_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
...
@@ -41,7 +41,7 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -41,7 +41,7 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
inputs
.
data
())...);
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
inputs
.
data
())...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
stream
,
output_shape
.
elements
())([
=
](
auto
i
)
{
data
([
&
](
auto
&&
...
ps
)
{
data
([
&
](
auto
&&
...
ps
)
{
auto
outidx
=
out_desc
.
multi
(
i
);
auto
outidx
=
out_desc
.
multi
(
i
);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
...
@@ -53,7 +53,7 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -53,7 +53,7 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
template
<
class
F
>
template
<
class
F
>
void
trinary_broadcast_vec_impl
(
void
trinary_broadcast_vec_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
hipStream_t
stream
,
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg3
.
get_shape
();
const
auto
&
b_shape
=
arg3
.
get_shape
();
...
@@ -79,7 +79,7 @@ void trinary_broadcast_vec_impl(
...
@@ -79,7 +79,7 @@ void trinary_broadcast_vec_impl(
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
vec4
<
type
>
buffer
[
2048
/
vec_size
];
MIGRAPH_DEVICE_SHARED
vec4
<
type
>
buffer
[
2048
/
vec_size
];
// Load bias into LDS
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
...
@@ -108,7 +108,7 @@ void trinary_broadcast_vec_impl(
...
@@ -108,7 +108,7 @@ void trinary_broadcast_vec_impl(
template
<
class
F
>
template
<
class
F
>
void
trinary_broadcast_impl
(
void
trinary_broadcast_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
hipStream_t
stream
,
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg3
.
get_shape
();
const
auto
&
b_shape
=
arg3
.
get_shape
();
...
@@ -132,7 +132,7 @@ void trinary_broadcast_impl(
...
@@ -132,7 +132,7 @@ void trinary_broadcast_impl(
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
const
std
::
size_t
n
=
output
.
size
();
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
...
@@ -154,7 +154,7 @@ void trinary_broadcast_impl(
...
@@ -154,7 +154,7 @@ void trinary_broadcast_impl(
}
}
template
<
class
F
>
template
<
class
F
>
void
binary_broadcast_vec_impl
(
F
f
,
void
binary_broadcast_vec_impl
(
hipStream_t
stream
,
F
f
,
const
argument
&
result
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg1
,
const
argument
&
arg2
)
const
argument
&
arg2
)
...
@@ -182,7 +182,7 @@ void binary_broadcast_vec_impl(F f,
...
@@ -182,7 +182,7 @@ void binary_broadcast_vec_impl(F f,
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
vec4
<
type
>
buffer
[
2048
/
vec_size
];
MIGRAPH_DEVICE_SHARED
vec4
<
type
>
buffer
[
2048
/
vec_size
];
// Load bias into LDS
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
...
@@ -209,7 +209,7 @@ void binary_broadcast_vec_impl(F f,
...
@@ -209,7 +209,7 @@ void binary_broadcast_vec_impl(F f,
}
}
template
<
class
F
>
template
<
class
F
>
void
binary_broadcast_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
void
binary_broadcast_impl
(
hipStream_t
stream
,
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
...
@@ -232,7 +232,7 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co
...
@@ -232,7 +232,7 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
const
std
::
size_t
n
=
output
.
size
();
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
...
@@ -253,7 +253,7 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co
...
@@ -253,7 +253,7 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_vec_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
// assert(x.get_shape().elements() == y.get_shape().elements());
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
...
@@ -262,7 +262,7 @@ void nary_standard_vec_impl(F f, argument result, Arguments... args)
...
@@ -262,7 +262,7 @@ void nary_standard_vec_impl(F f, argument result, Arguments... args)
const
std
::
size_t
vec_size
=
4
;
const
std
::
size_t
vec_size
=
4
;
auto
data
=
pack_vec4
(
inputs
.
data
()...);
auto
data
=
pack_vec4
(
inputs
.
data
()...);
auto
*
outp
=
as_vec4
(
output
.
data
());
auto
*
outp
=
as_vec4
(
output
.
data
());
gs_launch
(
output_shape
.
elements
()
/
vec_size
)([
=
](
auto
i
)
{
gs_launch
(
stream
,
output_shape
.
elements
()
/
vec_size
)([
=
](
auto
i
)
{
vec4
<
type
>
out
=
outp
[
i
];
vec4
<
type
>
out
=
outp
[
i
];
data
(
data
(
[
&
](
auto
...
xs
)
{
[
&
](
auto
...
xs
)
{
...
@@ -278,50 +278,50 @@ void nary_standard_vec_impl(F f, argument result, Arguments... args)
...
@@ -278,50 +278,50 @@ void nary_standard_vec_impl(F f, argument result, Arguments... args)
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
void
nary_standard_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
// assert(x.get_shape().elements() == y.get_shape().elements());
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
auto
data
=
pack
(
inputs
.
data
()...);
auto
data
=
pack
(
inputs
.
data
()...);
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())(
gs_launch
(
stream
,
output_shape
.
elements
())(
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
});
});
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
void
nary_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
same_shapes
=
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
if
(
standard
or
(
packed
and
same_shapes
))
if
(
standard
or
(
packed
and
same_shapes
))
nary_standard_impl
(
f
,
result
,
args
...);
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
else
else
nary_nonstandard_impl
(
f
,
result
,
args
...);
nary_nonstandard_impl
(
stream
,
f
,
result
,
args
...);
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
auto
nary_nonstandard
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
nary_nonstandard_impl
(
stream
,
f
,
result
,
args
...);
};
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
auto
nary_standard
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
nary_standard_impl
(
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
};
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
argument
result
,
Arguments
...
args
)
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
nary_impl
(
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
nary_impl
(
stream
,
f
,
result
,
args
...);
};
}
}
inline
auto
nary
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
inline
auto
nary
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// TODO: Check result and arg1 shape is the same
// TODO: Check result and arg1 shape is the same
...
@@ -339,18 +339,18 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
...
@@ -339,18 +339,18 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
if
(
divisible_by_4
)
binary_broadcast_vec_impl
(
f
,
result
,
arg1
,
arg2
);
binary_broadcast_vec_impl
(
stream
,
f
,
result
,
arg1
,
arg2
);
else
else
binary_broadcast_impl
(
f
,
result
,
arg1
,
arg2
);
binary_broadcast_impl
(
stream
,
f
,
result
,
arg1
,
arg2
);
return
;
return
;
}
}
}
}
nary_impl
(
f
,
result
,
arg1
,
arg2
);
nary_impl
(
stream
,
f
,
result
,
arg1
,
arg2
);
};
};
}
}
inline
auto
inline
auto
nary
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
nary
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// TODO: Check result and arg1 shape is the same
// TODO: Check result and arg1 shape is the same
...
@@ -369,13 +369,13 @@ nary(const argument& result, const argument& arg1, const argument& arg2, const a
...
@@ -369,13 +369,13 @@ nary(const argument& result, const argument& arg1, const argument& arg2, const a
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
if
(
divisible_by_4
)
trinary_broadcast_vec_impl
(
f
,
result
,
arg1
,
arg2
,
arg3
);
trinary_broadcast_vec_impl
(
stream
,
f
,
result
,
arg1
,
arg2
,
arg3
);
else
else
trinary_broadcast_impl
(
f
,
result
,
arg1
,
arg2
,
arg3
);
trinary_broadcast_impl
(
stream
,
f
,
result
,
arg1
,
arg2
,
arg3
);
return
;
return
;
}
}
}
}
nary_impl
(
f
,
result
,
arg1
,
arg2
,
arg3
);
nary_impl
(
stream
,
f
,
result
,
arg1
,
arg2
,
arg3
);
};
};
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
9979fada
...
@@ -153,9 +153,9 @@ struct hip_triadd
...
@@ -153,9 +153,9 @@ struct hip_triadd
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
device
::
add
(
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
device
::
add
(
ctx
.
get_stream
().
get
(),
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
return
args
.
at
(
3
);
}
}
};
};
...
@@ -168,9 +168,9 @@ struct hip_triadd_relu
...
@@ -168,9 +168,9 @@ struct hip_triadd_relu
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
device
::
add_relu
(
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
device
::
add_relu
(
ctx
.
get_stream
().
get
(),
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
return
args
.
at
(
3
);
}
}
};
};
...
@@ -183,9 +183,9 @@ struct hip_add_relu
...
@@ -183,9 +183,9 @@ struct hip_add_relu
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
device
::
add_relu
(
args
.
at
(
2
),
args
.
at
(
0
),
args
.
at
(
1
));
device
::
add_relu
(
ctx
.
get_stream
().
get
(),
args
.
at
(
2
),
args
.
at
(
0
),
args
.
at
(
1
));
return
args
.
at
(
2
);
return
args
.
at
(
2
);
}
}
};
};
...
...
src/targets/gpu/include/migraph/gpu/device/add.hpp
View file @
9979fada
...
@@ -3,14 +3,15 @@
...
@@ -3,14 +3,15 @@
#define MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_HPP
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
);
void
add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
);
void
add
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
);
void
add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraph/gpu/device/add_relu.hpp
View file @
9979fada
...
@@ -3,14 +3,15 @@
...
@@ -3,14 +3,15 @@
#define MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_RELU_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_RELU_HPP
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add_relu
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
);
void
add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
);
void
add_relu
(
const
argument
&
result
,
void
add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg2
,
const
argument
&
arg3
);
const
argument
&
arg3
);
...
...
src/targets/gpu/include/migraph/gpu/device/concat.hpp
View file @
9979fada
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
#include <migraph/argument.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
argument
argument
concat
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
,
std
::
vector
<
std
::
size_t
>
offsets
);
concat
(
hipStream_t
stream
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
,
std
::
vector
<
std
::
size_t
>
offsets
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraph/gpu/device/contiguous.hpp
View file @
9979fada
...
@@ -2,12 +2,13 @@
...
@@ -2,12 +2,13 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_KERNELS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_KERNELS_HPP
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
contiguous
(
argument
result
,
argument
arg
);
void
contiguous
(
hipStream_t
stream
,
argument
result
,
argument
arg
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment