Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5c4e15f2
Commit
5c4e15f2
authored
Nov 20, 2023
by
Paul
Browse files
Unify the concat versions
parent
602924d4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
146 deletions
+73
-146
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+72
-130
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
+1
-16
No files found.
src/targets/gpu/jit/concat.cpp
View file @
5c4e15f2
...
@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
...
@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
struct
concat_compiler
:
compiler
<
concat_compiler
>
struct
concat_compiler
:
compiler
<
concat_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"concat"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_concat"
,
"concat"
};
}
static
std
::
size_t
get_concat_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
return
inputs
.
back
().
elements
()
/
(
inputs
.
size
()
-
1
);
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
num_of_concat_inputs
=
v
.
get
(
"concat_inputs"
,
inputs
.
size
()
-
1
);
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
vectorize
vec
{};
if
(
axis
!=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
())
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
auto
src
=
interpolate_string
(
concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"concat_params"
,
enum_params
(
num_of_concat_inputs
,
"auto concat_x"
)},
{
"concat_args"
,
enum_params
(
num_of_concat_inputs
,
"concat_x"
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"concat_inputs"
]
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
};
// NOLINTNEXTLINE
static
const
char
*
const
fused_concat_kernel
=
R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
fused_concat_compiler
:
compiler
<
fused_concat_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_concat"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
...
@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
...
@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
std
::
vector
<
std
::
string
>
concat_params
;
std
::
vector
<
std
::
string
>
concat_params
;
std
::
vector
<
std
::
string
>
concat_args
;
std
::
vector
<
std
::
string
>
concat_args
;
for
(
const
auto
&
name
:
op_names
)
for
(
auto
i
:
range
(
op_names
.
size
())
)
{
{
const
auto
&
name
=
op_names
[
i
];
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
prefix
=
name
+
"_concat_x"
;
auto
prefix
=
to_c_id
(
name
+
std
::
to_string
(
i
)
+
"_concat_x"
)
;
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
i
)
{
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
j
)
{
return
"auto "
+
prefix
+
std
::
to_string
(
i
);
return
"auto "
+
prefix
+
std
::
to_string
(
j
);
});
});
std
::
vector
<
std
::
string
>
pack_args
=
{
"MIGRAPHX_LIFT("
+
name
+
")"
};
std
::
vector
<
std
::
string
>
pack_args
=
{
"MIGRAPHX_LIFT("
+
name
+
")"
};
transform
(
range
(
n
),
std
::
back_inserter
(
pack_args
),
[
&
](
auto
i
)
{
transform
(
range
(
n
),
std
::
back_inserter
(
pack_args
),
[
&
](
auto
j
)
{
return
prefix
+
std
::
to_string
(
i
);
return
prefix
+
std
::
to_string
(
j
);
});
});
concat_args
.
push_back
(
"pack("
+
join_strings
(
pack_args
,
", "
)
+
")"
);
concat_args
.
push_back
(
"pack("
+
join_strings
(
pack_args
,
", "
)
+
")"
);
}
}
auto
src
=
interpolate_string
(
fused_
concat_kernel
,
auto
src
=
interpolate_string
(
concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
...
@@ -189,50 +112,69 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
...
@@ -189,50 +112,69 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
if
(
op
.
name
()
==
"fused_concat"
)
transform
(
range
(
ins
->
module_inputs
().
size
()),
{
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
[
&
](
auto
i
)
{
transform
(
range
(
ins
->
module_inputs
().
size
()),
return
std
::
make_pair
(
ins
->
module_inputs
()[
i
]
->
name
(),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
"pointwise"
+
std
::
to_string
(
i
));
[
&
](
auto
i
)
{
});
return
std
::
make_pair
(
ins
->
module_inputs
()[
i
]
->
name
(),
v
[
"preamble"
]
=
transform_accumulate
(
"pointwise"
+
std
::
to_string
(
i
));
ins
->
module_inputs
().
begin
(),
});
ins
->
module_inputs
().
end
(),
v
[
"preamble"
]
=
transform_accumulate
(
std
::
string
{},
ins
->
module_inputs
().
begin
(),
std
::
plus
<>
{},
ins
->
module_inputs
().
end
(),
[
&
](
module_ref
mod
)
{
std
::
string
{},
return
generate_pointwise
(
*
mod
,
mod_names_lookup
.
at
(
mod
->
name
()))
+
"
\n
"
;
std
::
plus
<>
{},
});
[
&
](
module_ref
mod
)
{
std
::
vector
<
std
::
string
>
mod_names
;
return
generate_pointwise
(
*
mod
,
mod_names_lookup
.
at
(
mod
->
name
()))
+
"
\n
"
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
});
ins
->
module_inputs
().
end
()
-
1
,
std
::
vector
<
std
::
string
>
mod_names
;
std
::
back_inserter
(
mod_names
),
std
::
transform
(
ins
->
module_inputs
().
begin
(),
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
ins
->
module_inputs
().
end
()
-
1
,
v
[
"ops"
]
=
mod_names
;
std
::
back_inserter
(
mod_names
),
module_ref
last_mod
=
ins
->
module_inputs
().
back
();
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
v
[
"post"
]
=
"MIGRAPHX_LIFT("
+
mod_names_lookup
.
at
(
last_mod
->
name
())
+
")"
;
v
[
"ops"
]
=
mod_names
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
module_ref
last_mod
=
ins
->
module_inputs
().
back
();
std
::
transform
(
ins
->
module_inputs
().
begin
(),
v
[
"post"
]
=
"MIGRAPHX_LIFT("
+
mod_names_lookup
.
at
(
last_mod
->
name
())
+
")"
;
ins
->
module_inputs
().
end
()
-
1
,
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
std
::
inserter
(
mod_args
,
mod_args
.
end
()),
std
::
transform
(
ins
->
module_inputs
().
begin
(),
[
&
](
module_ref
mod
)
{
ins
->
module_inputs
().
end
()
-
1
,
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
std
::
inserter
(
mod_args
,
mod_args
.
end
()),
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
[
&
](
module_ref
mod
)
{
});
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
v
[
"args"
]
=
mod_args
;
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
auto
prefix_name
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
});
ins
->
module_inputs
().
end
()
-
1
,
v
[
"args"
]
=
mod_args
;
std
::
string
{},
auto
prefix_name
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
std
::
plus
<>
{},
ins
->
module_inputs
().
end
()
-
1
,
[
&
](
module_ref
mod
)
->
std
::
string
{
std
::
string
{},
auto
name
=
generate_name_from_ops
(
*
mod
);
std
::
plus
<>
{},
if
(
name
.
empty
())
[
&
](
module_ref
mod
)
->
std
::
string
{
return
""
;
auto
name
=
generate_name_from_ops
(
*
mod
);
return
name
+
"_"
;
if
(
name
.
empty
())
});
return
""
;
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
return
name
+
"_"
;
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
});
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
}
else
if
(
op
.
name
()
==
"concat"
)
{
auto
concat_inputs
=
ins
->
inputs
().
size
()
-
1
;
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
concat_inputs
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
std
::
vector
<
std
::
string
>
mod_names
(
concat_inputs
,
"op::id{}"
);
v
[
"ops"
]
=
mod_names
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
=
{{
"op::id{}"
,
1
}};
v
[
"args"
]
=
mod_args
;
}
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
}
};
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
View file @
5c4e15f2
...
@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
...
@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
return
_c
<
lens
[
Axis
]
>
;
return
_c
<
lens
[
Axis
]
>
;
}
}
template
<
index_int
Axis
,
class
...
Inputs
>
__device__
auto
concat
(
Inputs
...
inputs
)
{
return
[
=
](
auto
f
,
auto
...
ts
)
{
auto
idx
=
make_index
();
fold
([
&
](
auto
start
,
auto
input
)
{
concat_slices
<
Axis
>
(
input
,
start
,
ts
...)([
&
](
auto
y
,
auto
...
xs
)
{
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
y
[
i
]
=
f
(
input
[
i
],
xs
[
i
]...);
});
});
return
start
+
concat_ends
<
Axis
>
(
input
);
})(
_c
<
0
>
,
inputs
...);
};
}
template
<
index_int
Axis
,
class
...
InputPacks
>
template
<
index_int
Axis
,
class
...
InputPacks
>
__device__
auto
concat
2
(
InputPacks
...
input_packs
)
__device__
auto
concat
(
InputPacks
...
input_packs
)
{
{
return
[
=
](
auto
f
,
auto
...
ts
)
{
return
[
=
](
auto
f
,
auto
...
ts
)
{
auto
idx
=
make_index
();
auto
idx
=
make_index
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment