Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5c4e15f2
Commit
5c4e15f2
authored
Nov 20, 2023
by
Paul
Browse files
Unify the concat versions
parent
602924d4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
146 deletions
+73
-146
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+72
-130
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
+1
-16
No files found.
src/targets/gpu/jit/concat.cpp
View file @
5c4e15f2
...
@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
...
@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
struct
concat_compiler
:
compiler
<
concat_compiler
>
struct
concat_compiler
:
compiler
<
concat_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"concat"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_concat"
,
"concat"
};
}
static
std
::
size_t
get_concat_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
return
inputs
.
back
().
elements
()
/
(
inputs
.
size
()
-
1
);
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
num_of_concat_inputs
=
v
.
get
(
"concat_inputs"
,
inputs
.
size
()
-
1
);
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
vectorize
vec
{};
if
(
axis
!=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
())
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
auto
src
=
interpolate_string
(
concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"concat_params"
,
enum_params
(
num_of_concat_inputs
,
"auto concat_x"
)},
{
"concat_args"
,
enum_params
(
num_of_concat_inputs
,
"concat_x"
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"concat_inputs"
]
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
};
// NOLINTNEXTLINE
static
const
char
*
const
fused_concat_kernel
=
R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
fused_concat_compiler
:
compiler
<
fused_concat_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_concat"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
...
@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
...
@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
std
::
vector
<
std
::
string
>
concat_params
;
std
::
vector
<
std
::
string
>
concat_params
;
std
::
vector
<
std
::
string
>
concat_args
;
std
::
vector
<
std
::
string
>
concat_args
;
for
(
const
auto
&
name
:
op_names
)
for
(
auto
i
:
range
(
op_names
.
size
())
)
{
{
const
auto
&
name
=
op_names
[
i
];
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
prefix
=
name
+
"_concat_x"
;
auto
prefix
=
to_c_id
(
name
+
std
::
to_string
(
i
)
+
"_concat_x"
)
;
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
i
)
{
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
j
)
{
return
"auto "
+
prefix
+
std
::
to_string
(
i
);
return
"auto "
+
prefix
+
std
::
to_string
(
j
);
});
});
std
::
vector
<
std
::
string
>
pack_args
=
{
"MIGRAPHX_LIFT("
+
name
+
")"
};
std
::
vector
<
std
::
string
>
pack_args
=
{
"MIGRAPHX_LIFT("
+
name
+
")"
};
transform
(
range
(
n
),
std
::
back_inserter
(
pack_args
),
[
&
](
auto
i
)
{
transform
(
range
(
n
),
std
::
back_inserter
(
pack_args
),
[
&
](
auto
j
)
{
return
prefix
+
std
::
to_string
(
i
);
return
prefix
+
std
::
to_string
(
j
);
});
});
concat_args
.
push_back
(
"pack("
+
join_strings
(
pack_args
,
", "
)
+
")"
);
concat_args
.
push_back
(
"pack("
+
join_strings
(
pack_args
,
", "
)
+
")"
);
}
}
auto
src
=
interpolate_string
(
fused_
concat_kernel
,
auto
src
=
interpolate_string
(
concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
...
@@ -189,6 +112,8 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
...
@@ -189,6 +112,8 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
if
(
op
.
name
()
==
"fused_concat"
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
transform
(
range
(
ins
->
module_inputs
().
size
()),
transform
(
range
(
ins
->
module_inputs
().
size
()),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
...
@@ -233,6 +158,23 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
...
@@ -233,6 +158,23 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
});
});
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
}
else
if
(
op
.
name
()
==
"concat"
)
{
auto
concat_inputs
=
ins
->
inputs
().
size
()
-
1
;
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
concat_inputs
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
std
::
vector
<
std
::
string
>
mod_names
(
concat_inputs
,
"op::id{}"
);
v
[
"ops"
]
=
mod_names
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
=
{{
"op::id{}"
,
1
}};
v
[
"args"
]
=
mod_args
;
}
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
}
};
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
View file @
5c4e15f2
...
@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
...
@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
return
_c
<
lens
[
Axis
]
>
;
return
_c
<
lens
[
Axis
]
>
;
}
}
template
<
index_int
Axis
,
class
...
Inputs
>
__device__
auto
concat
(
Inputs
...
inputs
)
{
return
[
=
](
auto
f
,
auto
...
ts
)
{
auto
idx
=
make_index
();
fold
([
&
](
auto
start
,
auto
input
)
{
concat_slices
<
Axis
>
(
input
,
start
,
ts
...)([
&
](
auto
y
,
auto
...
xs
)
{
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
y
[
i
]
=
f
(
input
[
i
],
xs
[
i
]...);
});
});
return
start
+
concat_ends
<
Axis
>
(
input
);
})(
_c
<
0
>
,
inputs
...);
};
}
template
<
index_int
Axis
,
class
...
InputPacks
>
template
<
index_int
Axis
,
class
...
InputPacks
>
__device__
auto
concat
2
(
InputPacks
...
input_packs
)
__device__
auto
concat
(
InputPacks
...
input_packs
)
{
{
return
[
=
](
auto
f
,
auto
...
ts
)
{
return
[
=
](
auto
f
,
auto
...
ts
)
{
auto
idx
=
make_index
();
auto
idx
=
make_index
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment