Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2f48b11a
"include/vscode:/vscode.git/clone" did not exist on "b0a4674ce7d9299ad747b741fec1815445960ad3"
Unverified
Commit
2f48b11a
authored
Nov 02, 2022
by
Paul Fultz II
Committed by
GitHub
Nov 02, 2022
Browse files
Concat pointwise fusions (#1388)
parent
04e33ec5
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
192 additions
and
27 deletions
+192
-27
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+10
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+33
-3
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+28
-11
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
+21
-9
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+4
-1
test/verify/test_concat_broadcast_add.cpp
test/verify/test_concat_broadcast_add.cpp
+49
-0
test/verify/test_slice_concat_add.cpp
test/verify/test_slice_concat_add.cpp
+47
-0
No files found.
src/targets/gpu/compile_ops.cpp
View file @
2f48b11a
...
@@ -39,19 +39,26 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
...
@@ -39,19 +39,26 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct
precompile_op
struct
precompile_op
{
{
operation
op
=
op
::
identity
{};
operation
op
=
op
::
identity
{};
std
::
size_t
additional_args
=
1
;
bool
ignore_modules
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
op
,
"op"
));
return
pack
(
f
(
self
.
op
,
"op"
),
f
(
self
.
additional_args
,
"additional_args"
),
f
(
self
.
ignore_modules
,
"ignore_modules"
));
}
}
std
::
string
name
()
const
{
return
"gpu::precompile_op"
;
}
std
::
string
name
()
const
{
return
"gpu::precompile_op"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
{
inputs
.
pop_back
();
// Pop off additional args
inputs
.
resize
(
inputs
.
size
()
-
additional_args
);
if
(
ignore_modules
)
return
op
.
compute_shape
(
inputs
);
return
op
.
compute_shape
(
inputs
,
mods
);
return
op
.
compute_shape
(
inputs
,
mods
);
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
2f48b11a
...
@@ -772,11 +772,9 @@ struct find_layernorm_pointwise
...
@@ -772,11 +772,9 @@ struct find_layernorm_pointwise
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
layernorm
=
r
.
instructions
[
"layernorm"
];
auto
layernorm
=
r
.
instructions
[
"layernorm"
];
auto
*
pm
=
ins
->
module_inputs
().
front
();
if
(
not
layernorm
->
module_inputs
().
empty
())
if
(
not
layernorm
->
module_inputs
().
empty
())
return
;
return
;
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
inputs
=
layernorm
->
inputs
();
auto
inputs
=
layernorm
->
inputs
();
inputs
.
pop_back
();
inputs
.
pop_back
();
inputs
.
insert
(
inputs
.
end
(),
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
());
inputs
.
insert
(
inputs
.
end
(),
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
());
...
@@ -785,6 +783,37 @@ struct find_layernorm_pointwise
...
@@ -785,6 +783,37 @@ struct find_layernorm_pointwise
}
}
};
};
struct
find_concat_pointwise
{
auto
matcher
()
const
{
return
precompile_name
(
"pointwise"
)(
match
::
arg
(
0
)(
precompile_name
(
"concat"
).
bind
(
"concat"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
concat
=
r
.
instructions
[
"concat"
];
if
(
not
concat
->
module_inputs
().
empty
())
return
;
// TODO: Handle type conversions
if
(
ins
->
get_shape
().
type
()
!=
concat
->
get_shape
().
type
())
return
;
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
inputs
=
concat
->
inputs
();
inputs
.
pop_back
();
inputs
.
insert
(
inputs
.
end
(),
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
());
auto
op
=
concat
->
get_operator
();
op
.
from_value
({{
"additional_args"
,
ins
->
inputs
().
size
()
-
1
},
{
"ignore_modules"
,
true
}});
m
.
replace_instruction
(
ins
,
op
,
inputs
,
{
pm
});
}
};
void
fuse_ops
::
apply
(
module
&
m
)
const
void
fuse_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
find_contiguous_pointwise
{});
match
::
find_matches
(
m
,
find_contiguous_pointwise
{});
...
@@ -793,6 +822,7 @@ void fuse_ops::apply(module& m) const
...
@@ -793,6 +822,7 @@ void fuse_ops::apply(module& m) const
run_passes
(
m
,
{
dead_code_elimination
{}});
run_passes
(
m
,
{
dead_code_elimination
{}});
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_layernorm_pointwise
{},
find_layernorm_pointwise
{},
find_concat_pointwise
{},
find_gemm_pointwise
{},
find_gemm_pointwise
{},
find_contiguous_tranpose_gemm
{},
find_contiguous_tranpose_gemm
{},
find_commutative_broadcast
{});
find_commutative_broadcast
{});
...
...
src/targets/gpu/jit/concat.cpp
View file @
2f48b11a
...
@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
...
@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
static
const
char
*
const
concat_kernel
=
R"__migraphx__(
static
const
char
*
const
concat_kernel
=
R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
#include <args.hpp>
namespace migraphx {
namespace migraphx {
${preamble}
extern "C" {
extern "C" {
__global__ void ${kernel}(${params})
__global__ void ${kernel}(${params})
{
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y,
${concat_params},
auto... xs) {
concat<${axis}>(y, xs...);
concat<${axis}>(
${concat_args})(${post},
y, xs...);
});
});
}
}
...
@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
// TODO: Use reduce_dims
auto
num_of_concat_inputs
=
v
.
get
(
"concat_inputs"
,
inputs
.
size
()
-
1
);
hip_compile_options
options
;
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
auto
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
auto
src
=
interpolate_string
(
concat_kernel
,
auto
src
=
interpolate_string
(
{{
"kernel"
,
options
.
kernel_name
},
concat_kernel
,
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{{
"kernel"
,
options
.
kernel_name
},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
{
"concat_params"
,
enum_params
(
num_of_concat_inputs
,
"auto concat_x"
)},
{
"concat_args"
,
enum_params
(
num_of_concat_inputs
,
"concat_x"
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
auto
v
=
op
.
to_value
();
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"concat_inputs"
]
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
}
};
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
View file @
2f48b11a
...
@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
...
@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
return
Start
{}
*
output_shape
.
strides
[
Axis
];
return
Start
{}
*
output_shape
.
strides
[
Axis
];
});
});
constexpr
auto
s
=
make_shape
(
lens
,
strides
);
constexpr
auto
s
=
make_shape
(
lens
,
strides
);
return
make_tensor_view
(
&
out
[
offset
],
s
);
MIGRAPHX_ASSERT
(
offset
<
out
.
get_shape
().
element_space
());
MIGRAPHX_ASSERT
((
s
.
element_space
()
+
offset
)
<=
out
.
get_shape
().
element_space
());
return
make_tensor_view
(
out
.
data
()
+
offset
,
s
);
}
template
<
index_int
Axis
,
class
Input
,
class
Start
,
class
...
Ts
>
constexpr
auto
concat_slices
(
Input
input
,
Start
start
,
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
f
(
concat_slice
<
Axis
>
(
xs
,
input
,
start
)...);
};
}
}
template
<
index_int
Axis
,
class
Input
>
template
<
index_int
Axis
,
class
Input
>
...
@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
...
@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
return
_c
<
lens
[
Axis
]
>
;
return
_c
<
lens
[
Axis
]
>
;
}
}
template
<
index_int
Axis
,
class
Output
,
class
...
Inputs
>
template
<
index_int
Axis
,
class
...
Inputs
>
__device__
void
concat
(
Output
output
,
Inputs
...
inputs
)
__device__
auto
concat
(
Inputs
...
inputs
)
{
{
auto
idx
=
make_index
();
return
[
=
](
auto
f
,
auto
...
ts
)
{
fold
([
&
](
auto
start
,
auto
input
)
{
auto
idx
=
make_index
();
auto
y
=
concat_slice
<
Axis
>
(
output
,
input
,
start
);
fold
([
&
](
auto
start
,
auto
input
)
{
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
y
[
i
]
=
input
[
i
];
});
concat_slices
<
Axis
>
(
input
,
start
,
ts
...)([
&
](
auto
y
,
auto
...
xs
)
{
return
start
+
concat_ends
<
Axis
>
(
input
);
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
})(
_c
<
0
>
,
inputs
...);
[
&
](
auto
i
)
{
y
[
i
]
=
f
(
input
[
i
],
xs
[
i
]...);
});
});
return
start
+
concat_ends
<
Axis
>
(
input
);
})(
_c
<
0
>
,
inputs
...);
};
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/mlir.cpp
View file @
2f48b11a
...
@@ -101,7 +101,10 @@ struct mlir_handle
...
@@ -101,7 +101,10 @@ struct mlir_handle
mlir_handle
(
T
p
)
:
handle
(
ptr
{
p
})
{}
mlir_handle
(
T
p
)
:
handle
(
ptr
{
p
})
{}
T
get
()
const
{
return
handle
.
get
().
get
();
}
T
get
()
const
{
return
handle
.
get
().
get
();
// NOLINT(readability-redundant-smartptr-get)
}
T
release
()
{
return
handle
.
release
().
get
();
}
T
release
()
{
return
handle
.
release
().
get
();
}
...
...
test/verify/test_concat_broadcast_add.cpp
0 → 100644
View file @
2f48b11a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_concat_broadcast_add
:
verify_program
<
test_concat_broadcast_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
4
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
6
,
1
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s0
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s0
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s0
);
auto
concat
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
x
,
y
,
z
);
auto
b
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
s2
,
15
));
auto
bb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s1
.
lens
()}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
concat
,
bb
);
return
p
;
}
};
test/verify/test_slice_concat_add.cpp
0 → 100644
View file @
2f48b11a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_slice_concat_add
:
verify_program
<
test_slice_concat_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{
1
,
24
,
2
,
2
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
2
,
2
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s0
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s0
);
auto
slice
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
x
);
auto
concat
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
slice
,
y
,
y
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
concat
,
z
);
return
p
;
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment