Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a0921a37
Commit
a0921a37
authored
Aug 25, 2022
by
Paul
Browse files
Format
parent
be27f5cb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
16 deletions
+11
-16
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+4
-7
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
+7
-9
No files found.
src/targets/gpu/jit/concat.cpp
View file @
a0921a37
...
@@ -63,9 +63,8 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -63,9 +63,8 @@ struct concat_compiler : compiler<concat_compiler>
static
std
::
size_t
get_min_elements
(
const
std
::
vector
<
shape
>&
inputs
)
static
std
::
size_t
get_min_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
{
auto
it
=
std
::
min_element
(
inputs
.
begin
(),
inputs
.
end
(),
by
(
std
::
less
<>
{},
[](
auto
s
)
{
auto
it
=
std
::
min_element
(
return
s
.
elements
();
inputs
.
begin
(),
inputs
.
end
(),
by
(
std
::
less
<>
{},
[](
auto
s
)
{
return
s
.
elements
();
}));
}));
return
it
->
elements
();
return
it
->
elements
();
}
}
...
@@ -80,9 +79,7 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -80,9 +79,7 @@ struct concat_compiler : compiler<concat_compiler>
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
virtual_inputs
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
v
,
compute_global_for
(
ctx
,
get_min_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
compute_global_for
(
ctx
,
get_min_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
auto
src
=
interpolate_string
(
concat_kernel
,
auto
src
=
interpolate_string
(
concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
...
...
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
View file @
a0921a37
...
@@ -31,7 +31,7 @@
...
@@ -31,7 +31,7 @@
namespace
migraphx
{
namespace
migraphx
{
template
<
index_int
Axis
,
class
Output
,
class
Input
,
class
Start
>
template
<
index_int
Axis
,
class
Output
,
class
Input
,
class
Start
>
constexpr
auto
concat_slice
(
Output
out
,
Input
,
Start
)
constexpr
auto
concat_slice
(
Output
out
,
Input
,
Start
)
{
{
constexpr
auto
lens
=
get_shape_c
<
Input
>
{}.
lens
;
constexpr
auto
lens
=
get_shape_c
<
Input
>
{}.
lens
;
...
@@ -44,22 +44,20 @@ constexpr auto concat_slice(Output out, Input, Start)
...
@@ -44,22 +44,20 @@ constexpr auto concat_slice(Output out, Input, Start)
return
make_tensor_view
(
&
out
[
offset
],
s
);
return
make_tensor_view
(
&
out
[
offset
],
s
);
}
}
template
<
index_int
Axis
,
class
Input
>
template
<
index_int
Axis
,
class
Input
>
constexpr
auto
concat_ends
(
Input
)
constexpr
auto
concat_ends
(
Input
)
{
{
constexpr
auto
lens
=
get_shape_c
<
Input
>
{}.
lens
;
constexpr
auto
lens
=
get_shape_c
<
Input
>
{}.
lens
;
return
_c
<
lens
[
Axis
]
>
;
return
_c
<
lens
[
Axis
]
>
;
}
}
template
<
index_int
Axis
,
class
Output
,
class
...
Inputs
>
template
<
index_int
Axis
,
class
Output
,
class
...
Inputs
>
__device__
void
concat
(
Output
output
,
Inputs
...
inputs
)
__device__
void
concat
(
Output
output
,
Inputs
...
inputs
)
{
{
auto
idx
=
make_index
();
auto
idx
=
make_index
();
fold
([
&
](
auto
start
,
auto
input
)
{
fold
([
&
](
auto
start
,
auto
input
)
{
auto
y
=
concat_slice
<
Axis
>
(
output
,
input
,
start
);
auto
y
=
concat_slice
<
Axis
>
(
output
,
input
,
start
);
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
y
[
i
]
=
input
[
i
];
});
y
[
i
]
=
input
[
i
];
});
return
start
+
concat_ends
<
Axis
>
(
input
);
return
start
+
concat_ends
<
Axis
>
(
input
);
})(
_c
<
0
>
,
inputs
...);
})(
_c
<
0
>
,
inputs
...);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment