Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5c4e15f2
Commit
5c4e15f2
authored
Nov 20, 2023
by
Paul
Browse files
Unify the concat versions
parent
602924d4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
146 deletions
+73
-146
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+72
-130
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
+1
-16
No files found.
src/targets/gpu/jit/concat.cpp
View file @
5c4e15f2
...
...
@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
/// JIT compiler for the plain "concat" operator.
///
/// Renders the `concat_kernel` source template with values derived from the
/// input shapes and the operator's attribute map, then hands the result to the
/// HIP code-object compiler.
struct concat_compiler : compiler<concat_compiler>
{
    /// Names of the operators this compiler is registered for.
    std::vector<std::string> names() const { return {"concat"}; }

    /// Element count of the last shape divided by the number of remaining shapes.
    ///
    /// The last entry of `inputs` is the one used as the kernel output in
    /// `compile_op`; the result is only used there to size the launch.
    static std::size_t get_concat_elements(const std::vector<shape>& inputs)
    {
        const auto total_elements = inputs.back().elements();
        const auto divisor        = inputs.size() - 1;
        return total_elements / divisor;
    }

    /// Build and compile the kernel source for the given shapes.
    ///
    /// @param ctx     GPU context forwarded to the vectorizer and launch helpers.
    /// @param inputs  Shapes of all kernel arguments; the last one is taken as the output.
    /// @param v       Attribute map. "axis" is required; "concat_inputs", "kernel",
    ///                "post" and "preamble" are optional and fall back to defaults.
    /// @return        The operation produced by `compile_hip_code_object`.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        options.inputs      = inputs;
        options.output      = inputs.back();
        options.params      = "-Wno-float-equal";
        options.kernel_name = v.get("kernel", "concat_kernel");

        // Without an explicit override, every argument except the last is concatenated.
        const auto concat_count = v.get("concat_inputs", inputs.size() - 1);

        // Vectorization is only requested when the fast axis differs from the concat axis.
        const auto fast_axis = find_fast_axis(options.inputs);
        vectorize vec{};
        if(fast_axis != v.at("axis").to<std::size_t>())
            vec = vectorize::elements(ctx, fast_axis, options.inputs);

        const auto work_items = get_concat_elements(options.inputs) / vec.size;
        options.set_launch_params(v, compute_global_for(ctx, work_items, 256));

        auto src = interpolate_string(
            concat_kernel,
            {{"kernel", options.kernel_name},
             {"params", enum_params(inputs.size(), "void * private_p")},
             {"args", enum_params(inputs.size(), "private_p")},
             {"concat_params", enum_params(concat_count, "auto concat_x")},
             {"concat_args", enum_params(concat_count, "concat_x")},
             {"post", v.get("post", std::string{"op::id{}"})},
             {"transformers", make_transformer_args(vec)},
             {"preamble", v.get("preamble", std::string{})},
             {"axis", v.at("axis").to<std::string>()}});
        return compile_hip_code_object(src, options);
    }

    /// Compile the kernel for an instruction.
    ///
    /// When the instruction carries module inputs, the first module is emitted
    /// as a `post_concat` function in the preamble and wired in as the "post"
    /// callable; the kernel name is derived from that module's ops.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto v                    = op.to_value();
        const auto& fused_modules = ins->module_inputs();
        if(not fused_modules.empty())
        {
            auto* post_mod = fused_modules.front();
            // Arguments consumed by the fused module are not concat inputs.
            v["concat_inputs"] = ins->inputs().size() - post_mod->get_parameter_names().size();
            v["preamble"]      = generate_pointwise(*post_mod, "post_concat");
            v["post"]          = "MIGRAPHX_LIFT(post_concat)";
            v["kernel"]        = "concat_" + generate_name_from_ops(*post_mod) + "_kernel";
        }
        return compile_op(ctx, to_shapes(ins->inputs()), v);
    }
};
// Source template for the fused concat kernel.
//
// The text below is kept as a raw string; every ${...} token is a placeholder
// (preamble, kernel, params, transformers, args, concat_params, concat_args,
// axis, post) that is expected to be substituted before the source is compiled.
// The generated kernel routes its arguments through
// transform_args(make_tensors(), rotate_last(), ${transformers}) and then calls
// concat2<${axis}>(${concat_args})(${post}, y, xs...).
// NOLINTNEXTLINE
static
const
char
*
const
fused_concat_kernel
=
R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
fused_concat_compiler
:
compiler
<
fused_concat_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_concat"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_concat"
,
"concat"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
...
...
@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
std
::
vector
<
std
::
string
>
concat_params
;
std
::
vector
<
std
::
string
>
concat_args
;
for
(
const
auto
&
name
:
op_names
)
for
(
auto
i
:
range
(
op_names
.
size
())
)
{
const
auto
&
name
=
op_names
[
i
];
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
prefix
=
name
+
"_concat_x"
;
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
i
)
{
return
"auto "
+
prefix
+
std
::
to_string
(
i
);
auto
prefix
=
to_c_id
(
name
+
std
::
to_string
(
i
)
+
"_concat_x"
)
;
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
j
)
{
return
"auto "
+
prefix
+
std
::
to_string
(
j
);
});
std
::
vector
<
std
::
string
>
pack_args
=
{
"MIGRAPHX_LIFT("
+
name
+
")"
};
transform
(
range
(
n
),
std
::
back_inserter
(
pack_args
),
[
&
](
auto
i
)
{
return
prefix
+
std
::
to_string
(
i
);
transform
(
range
(
n
),
std
::
back_inserter
(
pack_args
),
[
&
](
auto
j
)
{
return
prefix
+
std
::
to_string
(
j
);
});
concat_args
.
push_back
(
"pack("
+
join_strings
(
pack_args
,
", "
)
+
")"
);
}
auto
src
=
interpolate_string
(
fused_
concat_kernel
,
auto
src
=
interpolate_string
(
concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
...
...
@@ -189,50 +112,69 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
transform
(
range
(
ins
->
module_inputs
().
size
()),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
[
&
](
auto
i
)
{
return
std
::
make_pair
(
ins
->
module_inputs
()[
i
]
->
name
(),
"pointwise"
+
std
::
to_string
(
i
));
});
v
[
"preamble"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_pointwise
(
*
mod
,
mod_names_lookup
.
at
(
mod
->
name
()))
+
"
\n
"
;
});
std
::
vector
<
std
::
string
>
mod_names
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
back_inserter
(
mod_names
),
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
v
[
"ops"
]
=
mod_names
;
module_ref
last_mod
=
ins
->
module_inputs
().
back
();
v
[
"post"
]
=
"MIGRAPHX_LIFT("
+
mod_names_lookup
.
at
(
last_mod
->
name
())
+
")"
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
inserter
(
mod_args
,
mod_args
.
end
()),
[
&
](
module_ref
mod
)
{
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
});
v
[
"args"
]
=
mod_args
;
auto
prefix_name
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
->
std
::
string
{
auto
name
=
generate_name_from_ops
(
*
mod
);
if
(
name
.
empty
())
return
""
;
return
name
+
"_"
;
});
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
if
(
op
.
name
()
==
"fused_concat"
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
transform
(
range
(
ins
->
module_inputs
().
size
()),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
[
&
](
auto
i
)
{
return
std
::
make_pair
(
ins
->
module_inputs
()[
i
]
->
name
(),
"pointwise"
+
std
::
to_string
(
i
));
});
v
[
"preamble"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_pointwise
(
*
mod
,
mod_names_lookup
.
at
(
mod
->
name
()))
+
"
\n
"
;
});
std
::
vector
<
std
::
string
>
mod_names
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
back_inserter
(
mod_names
),
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
v
[
"ops"
]
=
mod_names
;
module_ref
last_mod
=
ins
->
module_inputs
().
back
();
v
[
"post"
]
=
"MIGRAPHX_LIFT("
+
mod_names_lookup
.
at
(
last_mod
->
name
())
+
")"
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
inserter
(
mod_args
,
mod_args
.
end
()),
[
&
](
module_ref
mod
)
{
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
});
v
[
"args"
]
=
mod_args
;
auto
prefix_name
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
->
std
::
string
{
auto
name
=
generate_name_from_ops
(
*
mod
);
if
(
name
.
empty
())
return
""
;
return
name
+
"_"
;
});
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
}
else
if
(
op
.
name
()
==
"concat"
)
{
auto
concat_inputs
=
ins
->
inputs
().
size
()
-
1
;
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
concat_inputs
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
std
::
vector
<
std
::
string
>
mod_names
(
concat_inputs
,
"op::id{}"
);
v
[
"ops"
]
=
mod_names
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
=
{{
"op::id{}"
,
1
}};
v
[
"args"
]
=
mod_args
;
}
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
View file @
5c4e15f2
...
...
@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
return
_c
<
lens
[
Axis
]
>
;
}
/// Build a callable that concatenates `inputs` along `Axis`.
///
/// The returned callable takes an element-wise function `f` followed by extra
/// arguments `ts...`. For each input in turn, `concat_slices<Axis>` is invoked
/// with the input, the running start offset and `ts...`; the slices it yields
/// (`dst`, `extras...`) are then filled with `dst[k] = f(src[k], extras[k]...)`
/// for every `k` visited by `idx.global_stride` over the input's element
/// count. The start offset advances by `concat_ends<Axis>(src)` after each input.
template <index_int Axis, class... Inputs>
__device__ auto concat(Inputs... inputs)
{
    return [=](auto f, auto... ts) {
        auto idx = make_index();

        // Handles one input and returns the start offset for the next one.
        auto copy_one = [&](auto offset, auto src) {
            auto write_slice = [&](auto dst, auto... extras) {
                idx.global_stride(src.get_shape().elements(),
                                  [&](auto k) { dst[k] = f(src[k], extras[k]...); });
            };
            concat_slices<Axis>(src, offset, ts...)(write_slice);
            return offset + concat_ends<Axis>(src);
        };

        fold(copy_one)(_c<0>, inputs...);
    };
}
template
<
index_int
Axis
,
class
...
InputPacks
>
__device__
auto
concat
2
(
InputPacks
...
input_packs
)
__device__
auto
concat
(
InputPacks
...
input_packs
)
{
return
[
=
](
auto
f
,
auto
...
ts
)
{
auto
idx
=
make_index
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment