Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fc60486e
Unverified
Commit
fc60486e
authored
Sep 26, 2023
by
Chris Austen
Committed by
GitHub
Sep 26, 2023
Browse files
Merge branch 'develop' into enable_navi_32_ci
parents
6041f74c
434a06cf
Changes
174
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
546 additions
and
138 deletions
+546
-138
src/onnx/parse_roialign.cpp
src/onnx/parse_roialign.cpp
+7
-4
src/optimize_module.cpp
src/optimize_module.cpp
+7
-3
src/program.cpp
src/program.cpp
+1
-1
src/propagate_constant.cpp
src/propagate_constant.cpp
+4
-4
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+92
-32
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+141
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+26
-1
src/split_single_dyn_dim.cpp
src/split_single_dyn_dim.cpp
+0
-32
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+2
-0
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+53
-2
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+31
-20
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+11
-2
src/targets/gpu/device/targets.cpp
src/targets/gpu/device/targets.cpp
+66
-0
src/targets/gpu/device/targets.hpp.in
src/targets/gpu/device/targets.hpp.in
+48
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+42
-26
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
+2
-0
src/targets/gpu/include/migraphx/gpu/mlir.hpp
src/targets/gpu/include/migraphx/gpu/mlir.hpp
+2
-1
src/targets/gpu/jit/mlir.cpp
src/targets/gpu/jit/mlir.cpp
+1
-3
src/targets/gpu/jit/roialign.cpp
src/targets/gpu/jit/roialign.cpp
+1
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+9
-6
No files found.
src/onnx/parse_roialign.cpp
View file @
fc60486e
...
...
@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"RoiAlign"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*
parser
*/
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
std
::
string
coord_trans_mode
=
"half_pixel"
;
if
(
contains
(
info
.
attributes
,
"coordinate_transformation_mode"
))
std
::
string
coord_trans_mode
=
parser
.
opset_version
>=
16
?
"half_pixel"
:
"output_half_pixel"
;
if
(
const
auto
*
a
=
"coordinate_transformation_mode"
;
contains
(
info
.
attributes
,
a
))
{
coord_trans_mode
=
info
.
attributes
.
at
(
"coordinate_transformation_mode"
).
s
();
coord_trans_mode
=
info
.
attributes
.
at
(
a
).
s
();
}
if
(
not
contains
({
"half_pixel"
,
"output_half_pixel"
},
coord_trans_mode
))
{
MIGRAPHX_THROW
(
"coordinate_transformation_mode
\"
"
+
coord_trans_mode
+
...
...
src/optimize_module.cpp
View file @
fc60486e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
mpm
.
run_pass
(
simplify_reshapes
{});
mpm
.
run_pass
(
simplify_algebra
{});
// loop to further optimize after initial transformations
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
mpm
.
run_pass
(
simplify_reshapes
{});
mpm
.
run_pass
(
simplify_algebra
{});
}
mpm
.
run_pass
(
eliminate_common_subexpression
{});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
propagate_constant
{});
...
...
src/program.cpp
View file @
fc60486e
...
...
@@ -624,7 +624,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file.
*/
const
int
program_file_version
=
6
;
const
int
program_file_version
=
7
;
value
program
::
to_value
()
const
{
...
...
src/propagate_constant.cpp
View file @
fc60486e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
)
bool
skip_prop
o
gate
(
instruction_ref
ins
)
bool
skip_prop
a
gate
(
instruction_ref
ins
)
{
if
(
ins
->
name
()
==
"contiguous"
)
return
skip_prop
o
gate
(
ins
->
inputs
().
front
());
return
skip_prop
a
gate
(
ins
->
inputs
().
front
());
auto
&&
s
=
ins
->
get_shape
();
if
(
s
.
broadcasted
()
and
not
s
.
scalar
())
return
true
;
...
...
@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return
false
;
}
bool
is_const_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_prop
o
gate
(
ins
);
}
bool
is_const_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_prop
a
gate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
{
...
...
src/simplify_algebra.cpp
View file @
fc60486e
...
...
@@ -1325,48 +1325,59 @@ struct find_split_reshape
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
slc
=
r
.
instructions
[
"slice"
];
auto
rsp
=
r
.
instructions
[
"reshape"
];
auto
slc
=
r
.
instructions
[
"slice"
];
auto
rsp
=
r
.
instructions
[
"reshape"
];
auto
input
=
slc
->
inputs
().
front
();
// Only apply simplification when slices are on a single axis
auto
axes
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
;
if
(
axes
.
size
()
>
1
)
{
return
;
}
auto
input
=
slc
->
inputs
().
front
();
auto
split_outputs
=
get_splits
(
input
);
if
(
split_outputs
.
empty
())
{
return
;
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
if
(
i
->
outputs
().
size
()
==
1
)
{
auto
cont
=
i
->
outputs
().
front
();
return
cont
->
outputs
().
size
()
!=
1
;
}
return
false
;
}))
// Find all the reshapes (similar to rsp) that can be simplified
std
::
vector
<
instruction_ref
>
conts
;
std
::
vector
<
instruction_ref
>
vec_rsp
;
// Iterate through slice and contiguous outputs to allow simplifications when
// slice is followed by multiple reshapes
for
(
auto
&
i
:
split_outputs
)
{
return
;
std
::
copy_if
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
std
::
back_inserter
(
conts
),
[](
auto
j
)
{
return
j
->
name
()
==
"contiguous"
;
});
}
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
auto
cont
=
i
->
outputs
().
front
();
return
cont
->
outputs
().
front
();
});
for
(
auto
&
i
:
conts
)
{
std
::
copy_if
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
std
::
back_inserter
(
vec_rsp
),
[
&
](
auto
j
)
{
return
j
->
get_operator
()
==
rsp
->
get_operator
();
});
}
// all outputs are reshape and of the same shape
auto
dims
=
any_cast
<
op
::
reshape
>
(
rsp
->
get_operator
()).
dims
;
if
(
not
same_ops
(
vec_rsp
))
// No simplification needed if there is only one slice -> cont -> reshape
if
(
vec_rsp
.
size
()
<=
1
)
{
return
;
}
// ensure reshape happens after the axis dimension
auto
axis
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
[
0
];
auto
axis
=
axes
[
0
];
auto
slc_lens
=
slc
->
get_shape
().
lens
();
auto
slc_dim_size
=
std
::
accumulate
(
slc_lens
.
begin
()
+
axis
,
slc_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
input_lens
=
input
->
get_shape
().
lens
();
auto
input_size
=
input
->
get_shape
().
elements
();
auto
slc_axis_len
=
input_lens
[
axis
];
// search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size
...
...
@@ -1393,16 +1404,67 @@ struct find_split_reshape
{
rsp_axis
=
std
::
distance
(
rsp_strides
.
begin
(),
ait
);
}
// calculate reshape output shape
std
::
vector
<
int64_t
>
vec_dims
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
vec_dims
.
begin
(),
[
&
](
auto
is
)
{
return
is
->
get_shape
().
lens
()[
rsp_axis
];
});
// Calculate reshape output shape
// Need to find a reshape such that data represented by instructions in vec_rsp can be
// written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
rsp_out_lens
[
rsp_axis
]
=
1
;
auto
rsp_fixed_size
=
std
::
accumulate
(
rsp_out_lens
.
begin
(),
rsp_out_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
// cannot create a valid reshape for simplification
if
(
input_size
%
rsp_fixed_size
!=
0
)
{
return
;
}
auto
rsp_axis_len
=
input_size
/
rsp_fixed_size
;
rsp_out_lens
[
rsp_axis
]
=
rsp_axis_len
;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std
::
vector
<
int64_t
>
new_starts
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
new_starts
.
begin
(),
[
&
](
auto
is
)
{
auto
cont
=
is
->
inputs
().
front
();
auto
og_slc
=
cont
->
inputs
().
front
();
return
any_cast
<
op
::
slice
>
(
og_slc
->
get_operator
()).
starts
[
0
]
*
rsp_axis_len
/
slc_axis_len
;
});
std
::
vector
<
int64_t
>
new_ends
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
new_ends
.
begin
(),
[
&
](
auto
is
)
{
auto
cont
=
is
->
inputs
().
front
();
auto
og_slc
=
cont
->
inputs
().
front
();
return
any_cast
<
op
::
slice
>
(
og_slc
->
get_operator
()).
ends
[
0
]
*
rsp_axis_len
/
slc_axis_len
;
});
// insert the reshape instruction and add contiguous if needed
if
(
not
input
->
get_shape
().
standard
())
...
...
@@ -1413,16 +1475,14 @@ struct find_split_reshape
std
::
next
(
input
),
make_op
(
"reshape"
,
{{
"dims"
,
rsp_out_lens
}}),
input
);
// replace the original reshape with slice
int64_t
start
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
vec_rsp
.
size
();
++
i
)
{
m
.
replace_instruction
(
vec_rsp
[
i
],
make_op
(
"slice"
,
{{
"axes"
,
{
rsp_axis
}},
{
"starts"
,
{
start
}},
{
"ends"
,
{
start
+
vec_dim
s
[
i
]}}}),
{{
"axes"
,
{
rsp_axis
}},
{
"starts"
,
{
new_
start
s
[
i
]
}},
{
"ends"
,
{
new_end
s
[
i
]}}}),
rsp_ins
);
start
+=
vec_dims
[
i
];
}
}
};
...
...
src/simplify_dyn_ops.cpp
0 → 100644
View file @
fc60486e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct
find_static_2in_broadcasts
{
auto
matcher
()
const
{
return
match
::
broadcast
(
match
::
nargs
(
2
),
match
::
arg
(
0
)(
match
::
static_shape
()),
match
::
arg
(
1
)(
match
::
static_shape
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
out_lens
=
ins
->
get_shape
().
lens
();
auto
broadcast_op
=
ins
->
get_operator
();
if
(
broadcast_op
.
name
()
==
"broadcast"
)
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
}});
}
else
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
},
{
"out_dyn_dims"
,
{}}});
}
m
.
replace_instruction
(
ins
,
broadcast_op
,
ins
->
inputs
().
at
(
0
));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct
find_const_3in_slice
{
auto
matcher
()
const
{
return
match
::
name
(
"slice"
)(
match
::
nargs
(
3
),
match
::
arg
(
1
)(
match
::
is_constant
()),
match
::
arg
(
2
)(
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
inputs
=
ins
->
inputs
();
argument
starts_arg
=
inputs
.
at
(
1
)
->
eval
();
argument
ends_arg
=
inputs
.
at
(
2
)
->
eval
();
if
(
not
starts_arg
.
empty
()
and
not
ends_arg
.
empty
())
{
std
::
vector
<
int64_t
>
starts_vec
;
std
::
vector
<
int64_t
>
ends_vec
;
starts_arg
.
visit
([
&
](
auto
output
)
{
starts_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
ends_arg
.
visit
([
&
](
auto
output
)
{
ends_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
auto
slice_val
=
ins
->
get_operator
().
to_value
();
auto
axes_vec
=
slice_val
.
at
(
"axes"
).
to_vector
<
int64_t
>
();
m
.
replace_instruction
(
ins
,
make_op
(
"slice"
,
{{
"starts"
,
starts_vec
},
{
"ends"
,
ends_vec
},
{
"axes"
,
axes_vec
}}),
inputs
.
at
(
0
));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct
find_const_4in_slice
{
auto
matcher
()
const
{
return
match
::
name
(
"slice"
)(
match
::
nargs
(
4
),
match
::
arg
(
1
)(
match
::
is_constant
()),
match
::
arg
(
2
)(
match
::
is_constant
()),
match
::
arg
(
3
)(
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
inputs
=
ins
->
inputs
();
argument
starts_arg
=
inputs
.
at
(
1
)
->
eval
();
argument
ends_arg
=
inputs
.
at
(
2
)
->
eval
();
argument
axes_arg
=
inputs
.
at
(
3
)
->
eval
();
if
(
not
starts_arg
.
empty
()
and
not
ends_arg
.
empty
()
and
not
axes_arg
.
empty
())
{
std
::
vector
<
int64_t
>
starts_vec
;
std
::
vector
<
int64_t
>
ends_vec
;
std
::
vector
<
int64_t
>
axes_vec
;
starts_arg
.
visit
([
&
](
auto
output
)
{
starts_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
ends_arg
.
visit
([
&
](
auto
output
)
{
ends_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
axes_arg
.
visit
([
&
](
auto
output
)
{
axes_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
m
.
replace_instruction
(
ins
,
make_op
(
"slice"
,
{{
"starts"
,
starts_vec
},
{
"ends"
,
ends_vec
},
{
"axes"
,
axes_vec
}}),
inputs
.
at
(
0
));
}
}
};
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_static_2in_broadcasts
{},
find_const_3in_slice
{},
find_const_4in_slice
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/simplify_reshapes.cpp
View file @
fc60486e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary
}
};
struct
find_broadcast_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
arg
(
0
)(
match
::
name
(
"multibroadcast"
).
bind
(
"bcast_ins"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
ins_lens
=
ins
->
get_shape
().
lens
();
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
input
=
bcast_ins
->
inputs
().
front
();
// for now, focusing on scalar transformation
if
(
not
input
->
get_shape
().
scalar
())
return
;
auto
new_mbcast
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins_lens
}}),
input
);
m
.
replace_instruction
(
ins
,
new_mbcast
);
}
};
struct
find_slice_transpose
{
auto
matcher
()
const
...
...
@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice
{},
find_nested_concat
{},
find_transpose_slice
{},
find_broadcast_transpose
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{});
dead_code_elimination
{}.
apply
(
m
);
...
...
src/split_single_dyn_dim.cpp
View file @
fc60486e
...
...
@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it
->
max
};
}
namespace
{
struct
find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto
matcher
()
const
{
return
match
::
broadcast
(
match
::
nargs
(
2
),
match
::
arg
(
0
)(
match
::
static_shape
()),
match
::
arg
(
1
)(
match
::
static_shape
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
out_lens
=
ins
->
get_shape
().
lens
();
auto
broadcast_op
=
ins
->
get_operator
();
if
(
broadcast_op
.
name
()
==
"broadcast"
)
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
}});
}
else
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
},
{
"out_dyn_dims"
,
{}}});
}
m
.
replace_instruction
(
ins
,
broadcast_op
,
ins
->
inputs
().
at
(
0
));
}
};
}
// namespace
/**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
...
...
@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check
->
dyn_param_str
,
migraphx
::
shape
{
dyn_param_shape
.
type
(),
static_lens
});
auto
outputs
=
submod
->
add_instructions
(
mm
,
map_ins
);
submod
->
add_return
({
outputs
});
match
::
find_matches
(
*
submod
,
find_static_2in_broadcasts
{});
submodules
.
push_back
(
submod
);
}
// redirect to select_module operator and return
...
...
src/targets/gpu/CMakeLists.txt
View file @
fc60486e
...
...
@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
add_library
(
migraphx_device
${
DEVICE_GPU_SRCS
}
)
...
...
@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries
(
migraphx_device PUBLIC migraphx
)
target_link_libraries
(
migraphx_device PRIVATE compile_for_gpu
)
target_include_directories
(
migraphx_device PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_BINAR_DIR
}
/include>
)
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/include>
)
target_compile_options
(
migraphx_device PRIVATE -Wno-ignored-attributes
)
migraphx_generate_export_header
(
migraphx_device DIRECTORY migraphx/gpu/device
)
...
...
src/targets/gpu/compile_hip.cpp
View file @
fc60486e
...
...
@@ -115,6 +115,12 @@ struct hiprtc_program
std
::
string
cpp_src
=
""
;
std
::
string
cpp_name
=
""
;
hiprtc_program
(
const
std
::
string
&
src
,
const
std
::
string
&
name
=
"main.cpp"
)
:
cpp_src
(
src
),
cpp_name
(
name
)
{
create_program
();
}
hiprtc_program
(
std
::
vector
<
hiprtc_src_file
>
srcs
)
{
for
(
auto
&&
src
:
srcs
)
...
...
@@ -130,6 +136,14 @@ struct hiprtc_program
include_names
.
push_back
(
std
::
move
(
src
.
path
));
}
}
create_program
();
}
void
create_program
()
{
assert
(
not
cpp_src
.
empty
());
assert
(
not
cpp_name
.
empty
());
assert
(
headers
.
size
()
==
include_names
.
size
());
prog
=
hiprtc_program_create
(
cpp_src
.
c_str
(),
cpp_name
.
c_str
(),
headers
.
size
(),
...
...
@@ -137,7 +151,7 @@ struct hiprtc_program
include_names
.
data
());
}
void
compile
(
const
std
::
vector
<
std
::
string
>&
options
)
const
void
compile
(
const
std
::
vector
<
std
::
string
>&
options
,
bool
quiet
=
false
)
const
{
if
(
enabled
(
MIGRAPHX_TRACE_HIPRTC
{}))
std
::
cout
<<
"hiprtc "
<<
join_strings
(
options
,
" "
)
<<
" "
<<
cpp_name
<<
std
::
endl
;
...
...
@@ -148,7 +162,7 @@ struct hiprtc_program
[](
const
std
::
string
&
s
)
{
return
s
.
c_str
();
});
auto
result
=
hiprtcCompileProgram
(
prog
.
get
(),
c_options
.
size
(),
c_options
.
data
());
auto
prog_log
=
log
();
if
(
not
prog_log
.
empty
())
if
(
not
prog_log
.
empty
()
and
not
quiet
)
{
std
::
cerr
<<
prog_log
<<
std
::
endl
;
}
...
...
@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return
{
prog
.
get_code_obj
()};
}
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
)
{
hiprtc_program
prog
{
" "
};
try
{
prog
.
compile
(
flags
,
true
);
return
true
;
}
catch
(...)
{
return
false
;
}
}
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
{
...
...
@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return
{
compiler
.
compile
(
srcs
)};
}
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
)
{
src_compiler
compiler
;
compiler
.
compiler
=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
);
compiler
.
flags
=
join_strings
(
flags
,
" "
)
+
" -x hip -c --offload-arch=gfx900 --cuda-device-only"
;
std
::
string
src
;
src_file
input
;
input
.
path
=
"main.cpp"
;
input
.
content
=
std
::
make_pair
(
src
.
data
(),
src
.
data
()
+
src
.
size
());
try
{
compiler
.
compile
({
input
});
return
true
;
}
catch
(...)
{
return
false
;
}
}
#endif // MIGRAPHX_USE_HIPRTC
std
::
string
enum_params
(
std
::
size_t
count
,
std
::
string
param
)
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
fc60486e
...
...
@@ -91,28 +91,39 @@ __content__
return
replace_string
(
args_hpp
,
"__content__"
,
inner
);
}
static
std
::
vector
<
std
::
string
>
get_compiler_warnings
()
{
std
::
vector
<
std
::
string
>
warnings
=
{
"-Weverything"
,
"-Wno-c++98-compat"
,
"-Wno-c++98-compat-pedantic"
,
"-Wno-conversion"
,
"-Wno-double-promotion"
,
"-Wno-exit-time-destructors"
,
"-Wno-extra-semi"
,
"-Wno-extra-semi-stmt"
,
"-Wno-float-conversion"
,
"-Wno-gnu-anonymous-struct"
,
"-Wno-gnu-zero-variadic-macro-arguments"
,
"-Wno-missing-prototypes"
,
"-Wno-nested-anon-types"
,
"-Wno-padded"
,
"-Wno-shorten-64-to-32"
,
"-Wno-sign-conversion"
,
"-Wno-sign-compare"
,
"-Wno-unused-command-line-argument"
,
"-Wno-weak-vtables"
,
"-Wno-c99-extensions"
,
};
if
(
hip_has_flags
({
"-Werror"
,
"-Wunsafe-buffer-usage"
}))
warnings
.
push_back
(
"-Wno-unsafe-buffer-usage"
);
return
warnings
;
}
const
std
::
vector
<
std
::
string
>&
compiler_warnings
()
{
static
std
::
vector
<
std
::
string
>
warnings
=
{
"-Weverything"
,
"-Wno-c++98-compat"
,
"-Wno-c++98-compat-pedantic"
,
"-Wno-conversion"
,
"-Wno-double-promotion"
,
"-Wno-exit-time-destructors"
,
"-Wno-extra-semi"
,
"-Wno-extra-semi-stmt"
,
"-Wno-float-conversion"
,
"-Wno-gnu-anonymous-struct"
,
"-Wno-gnu-zero-variadic-macro-arguments"
,
"-Wno-missing-prototypes"
,
"-Wno-nested-anon-types"
,
"-Wno-padded"
,
"-Wno-shorten-64-to-32"
,
"-Wno-sign-conversion"
,
"-Wno-sign-compare"
,
"-Wno-unused-command-line-argument"
,
"-Wno-weak-vtables"
,
"-Wno-c99-extensions"
};
static
std
::
vector
<
std
::
string
>
warnings
=
get_compiler_warnings
();
return
warnings
;
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
View file @
fc60486e
...
...
@@ -26,7 +26,9 @@
#include <hip/hip_runtime.h>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/targets.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -84,8 +86,15 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
hipError_t
kernel_launch_status
=
hipGetLastError
();
if
(
kernel_launch_status
!=
hipSuccess
)
{
MIGRAPHX_THROW
(
"MIGraphX device kernel failed to launch with error: "
+
std
::
string
(
hipGetErrorString
(
kernel_launch_status
)));
std
::
string
message
=
hipGetErrorString
(
kernel_launch_status
);
if
(
not
contains
(
get_targets
(),
get_device_name
()))
{
message
+=
". Trying to run a kernel for "
+
get_device_name
()
+
" but MIGraphX was built for targets "
+
get_targets_as_string
()
+
". Please rebuild MIGraphX with -DGPU_TARGETS='"
+
get_device_name
()
+
"'."
;
}
MIGRAPHX_THROW
(
"MIGraphX device kernel failed to launch with error: "
+
message
);
}
};
}
...
...
src/targets/gpu/device/targets.cpp
0 → 100644
View file @
fc60486e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/device/targets.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
static
std
::
vector
<
std
::
string
>
parse_targets
()
{
return
split_string
(
MIGRAPHX_GPU_TARGETS
,
';'
);
}
const
std
::
vector
<
std
::
string
>&
get_targets
()
{
static
auto
result
=
parse_targets
();
return
result
;
}
std
::
string
get_targets_as_string
()
{
return
join_strings
(
get_targets
(),
", "
);
}
static
int
get_device_id
()
{
int
device
;
auto
status
=
hipGetDevice
(
&
device
);
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"No device"
);
return
device
;
}
std
::
string
get_device_name
()
{
hipDeviceProp_t
props
{};
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Failed to get device properties"
);
return
props
.
gcnArchName
;
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/device/targets.hpp.in
0 → 100644
View file @
fc60486e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
const std::vector<std::string>& get_targets();
std::string get_targets_as_string();
std::string get_device_name();
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
src/targets/gpu/fuse_mlir.cpp
View file @
fc60486e
...
...
@@ -103,7 +103,10 @@ struct mlir_op
}
if
(
ins
->
name
()
==
"@return"
)
{
return
ins_shapes
[
ins
->
inputs
().
at
(
0
)].
with_type
(
type
);
auto
s
=
ins_shapes
[
ins
->
inputs
().
at
(
0
)].
with_type
(
type
);
if
(
not
s
.
standard
())
MIGRAPHX_THROW
(
"MLIR doesnt support non-standard output"
);
return
s
;
}
std
::
vector
<
shape
>
input_shapes
;
input_shapes
.
resize
(
ins
->
inputs
().
size
());
...
...
@@ -235,8 +238,7 @@ struct find_mlir_fused_ops
"log"
,
"recip"
,
"rsqrt"
,
// There are bugs in MLIR right now for models using sigmoid so disable it for now
// "sigmoid",
"sigmoid"
,
"softmax"
,
"tanh"
,
};
...
...
@@ -300,10 +302,8 @@ struct find_mlir_fused_ops
}
};
struct
find_mlir_standalone_
convolution_
op
struct
find_mlir_standalone_op
{
auto
matcher
()
const
{
return
match
::
name
(
"convolution"
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
conv_based_op
=
r
.
result
;
...
...
@@ -325,6 +325,16 @@ struct find_mlir_standalone_convolution_op
}
};
struct
find_mlir_standalone_convolution_op
:
find_mlir_standalone_op
{
auto
matcher
()
const
{
return
match
::
name
(
"convolution"
);
}
};
struct
find_mlir_standalone_dot_op
:
find_mlir_standalone_op
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
);
}
};
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
...
...
@@ -332,7 +342,7 @@ struct find_mlir_standalone_convolution_op
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution". If the variable is not defined MIGraphX
* operations: "fused", "convolution"
, "dot"
. If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
...
...
@@ -347,31 +357,33 @@ bool is_requested(std::string_view option)
return
contains
(
options
,
option
);
}
bool
is_
fusion_
enabled
()
bool
is_enabled
(
std
::
string_view
op_name
,
context
*
ctx
)
{
if
(
is_self_decide
())
{
return
true
;
}
return
is_requested
(
"fused"
);
}
bool
is_standalone_convs_enabled
(
context
*
ctx
)
{
if
(
is_self_decide
())
{
if
(
ctx
==
nullptr
)
if
(
op_name
==
"fused"
)
{
return
false
;
return
true
;
}
else
if
(
op_name
==
"convolution"
)
{
if
(
ctx
==
nullptr
)
{
return
false
;
}
else
{
const
auto
&
device
=
ctx
->
get_current_device
();
const
std
::
string
navi_family
{
"gfx110"
};
return
starts_with
(
device
.
get_gfx_name
(),
navi_family
);
}
}
else
{
const
auto
&
device
=
ctx
->
get_current_device
();
const
std
::
string
navi_family
{
"gfx110"
};
return
starts_with
(
device
.
get_gfx_name
(),
navi_family
);
return
false
;
}
}
return
is_requested
(
"convolution"
);
return
is_requested
(
op_name
);
}
}
// namespace
...
...
@@ -380,21 +392,25 @@ bool is_standalone_convs_enabled(context* ctx)
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
{
#ifdef MIGRAPHX_MLIR
if
(
is_
fusion_
enabled
())
if
(
is_enabled
(
"fused"
,
this
->
ctx
))
{
match
::
find_matches
(
mpm
,
find_mlir_fused_ops
{});
}
if
(
is_
standalone_convs_enabled
(
this
->
ctx
))
if
(
is_
enabled
(
"convolution"
,
this
->
ctx
))
{
match
::
find_matches
(
mpm
,
find_mlir_standalone_convolution_op
{});
}
if
(
is_enabled
(
"dot"
,
this
->
ctx
))
{
match
::
find_matches
(
mpm
,
find_mlir_standalone_dot_op
{});
}
#else
(
void
)
mpm
;
#endif
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
View file @
fc60486e
...
...
@@ -58,6 +58,8 @@ struct hiprtc_src_file
}
};
MIGRAPHX_GPU_EXPORT
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
);
MIGRAPHX_GPU_EXPORT
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
hiprtc_src_file
>
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
);
...
...
src/targets/gpu/include/migraphx/gpu/mlir.hpp
View file @
fc60486e
...
...
@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
MIGRAPHX_GPU_EXPORT
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
const
std
::
vector
<
shape
>&
inputs
);
const
std
::
vector
<
shape
>&
inputs
,
bool
exhaustive
);
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/jit/mlir.cpp
View file @
fc60486e
...
...
@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
const
operation
&
,
bool
exhaustive
)
const
{
if
(
not
exhaustive
)
return
nullopt
;
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
*
smod
=
ins
->
module_inputs
().
front
();
return
get_tuning_config_mlir
(
ctx
,
*
smod
,
shapes
);
return
get_tuning_config_mlir
(
ctx
,
*
smod
,
shapes
,
exhaustive
);
}
};
...
...
src/targets/gpu/jit/roialign.cpp
View file @
fc60486e
...
...
@@ -81,7 +81,7 @@ struct roialign_compiler : compiler<roialign_compiler>
// coord_trans_mode
auto
ctm
=
v
.
at
(
"coordinate_transformation_mode"
).
to
<
std
::
string
>
();
float
rois_offset
=
(
ctm
==
"
output_
half_pixel"
)
?
-
0.5
f
:
0.0
f
;
float
rois_offset
=
(
ctm
==
"half_pixel"
)
?
-
0.5
f
:
0.0
f
;
options
.
params
+=
" -DROIS_OFFSET="
+
std
::
to_string
(
rois_offset
);
// spatial_scale
...
...
src/targets/gpu/mlir.cpp
View file @
fc60486e
...
...
@@ -682,11 +682,12 @@ struct mlir_program
MIGRAPHX_THROW
(
"Failed setting tuning key: "
+
*
str
);
}
tuning_config
get_tuning_config
()
MIGRAPHX_TIDY_CONST
tuning_config
get_tuning_config
(
bool
exhaustive
)
MIGRAPHX_TIDY_CONST
{
tuning_config
tc
;
run_high_level_pipeline
();
auto
tuning_mode
=
RocmlirTuningParamSetKindFull
;
auto
tuning_mode
=
exhaustive
?
RocmlirTuningParamSetKindFull
:
RocmlirTuningParamSetKindQuick
;
if
(
enabled
(
MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
{}))
tuning_mode
=
RocmlirTuningParamSetKindExhaustive
;
mlir_tuning_space
params
{
mlirRockTuningSpaceCreate
(
mmodule
.
get
(),
tuning_mode
)};
...
...
@@ -914,15 +915,17 @@ instruction_ref insert_mlir(module& m,
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
}
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
const
std
::
vector
<
shape
>&
inputs
)
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
const
std
::
vector
<
shape
>&
inputs
,
bool
exhaustive
)
{
adjust_param_shapes
(
m
,
inputs
);
mlir_program
mp
;
mp
.
set_gpu_properties
(
migraphx_ctx
);
mp
.
parse
(
m
);
return
mp
.
get_tuning_config
();
return
mp
.
get_tuning_config
(
exhaustive
);
}
#else
...
...
@@ -951,7 +954,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return
m
.
end
();
}
tuning_config
get_tuning_config_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
shape
>&
)
tuning_config
get_tuning_config_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
shape
>&
,
bool
)
{
return
{};
}
...
...
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment