Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a045fb19
Commit
a045fb19
authored
Sep 27, 2023
by
Alan Turner
Browse files
Merge branch 'develop' into ck-flash-attn
parents
135eb63e
434a06cf
Changes
217
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
610 additions
and
191 deletions
+610
-191
src/load_save.cpp
src/load_save.cpp
+21
-0
src/msgpack.cpp
src/msgpack.cpp
+56
-8
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+21
-14
src/onnx/parse_constant.cpp
src/onnx/parse_constant.cpp
+26
-3
src/onnx/parse_pooling.cpp
src/onnx/parse_pooling.cpp
+33
-22
src/onnx/parse_resize.cpp
src/onnx/parse_resize.cpp
+19
-25
src/onnx/parse_roialign.cpp
src/onnx/parse_roialign.cpp
+7
-4
src/optimize_module.cpp
src/optimize_module.cpp
+7
-3
src/pad_calc.cpp
src/pad_calc.cpp
+40
-1
src/program.cpp
src/program.cpp
+1
-1
src/propagate_constant.cpp
src/propagate_constant.cpp
+4
-4
src/shape.cpp
src/shape.cpp
+23
-16
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+96
-33
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+141
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+27
-2
src/split_single_dyn_dim.cpp
src/split_single_dyn_dim.cpp
+0
-32
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+1
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+3
-0
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+53
-2
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+31
-20
No files found.
src/load_save.cpp
View file @
a045fb19
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/instruction.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
...
@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt
...
@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt
{
{
write_buffer
(
filename
,
save_buffer
(
p
,
options
));
write_buffer
(
filename
,
save_buffer
(
p
,
options
));
}
}
// MIOpen doesn't support serializing fusion plans with Find-2.0 APIs
void
print_miopen_warning
(
const
program
&
p
)
{
auto
mods
=
p
.
get_modules
();
if
(
std
::
any_of
(
mods
.
begin
(),
mods
.
end
(),
[](
const
auto
*
m
)
{
return
std
::
any_of
(
m
->
begin
(),
m
->
end
(),
[](
const
instruction
&
i
)
{
return
i
.
name
()
==
"gpu::miopen_fusion"
;
});
}))
{
std
::
cout
<<
"[WARNING]: Program has miopen_fusion instructions for which tuned solutions "
"are not stored inside serialized MIGraphX program. Consider serializing with "
"MIGRAPHX_DISABLE_MIOPEN_FUSION=1 flag set."
<<
std
::
endl
;
;
}
}
std
::
vector
<
char
>
save_buffer
(
const
program
&
p
,
const
file_options
&
options
)
std
::
vector
<
char
>
save_buffer
(
const
program
&
p
,
const
file_options
&
options
)
{
{
value
v
=
p
.
to_value
();
value
v
=
p
.
to_value
();
print_miopen_warning
(
p
);
std
::
vector
<
char
>
buffer
;
std
::
vector
<
char
>
buffer
;
if
(
options
.
format
==
"msgpack"
)
if
(
options
.
format
==
"msgpack"
)
{
{
...
...
src/msgpack.cpp
View file @
a045fb19
...
@@ -25,6 +25,33 @@
...
@@ -25,6 +25,33 @@
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <msgpack.hpp>
#include <msgpack.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Leave an extra byte for error checking
constexpr
std
::
size_t
msgpack_size_limit
=
std
::
numeric_limits
<
uint32_t
>::
max
()
-
1
;
template
<
class
Range
>
std
::
size_t
msgpack_chunk_size
(
const
Range
&
r
)
{
return
1
+
(
r
.
size
()
-
1
)
/
msgpack_size_limit
;
}
template
<
class
Iterator
,
class
F
>
void
msgpack_chunk_for_each
(
Iterator
start
,
Iterator
last
,
F
f
)
{
while
(
std
::
distance
(
start
,
last
)
>
msgpack_size_limit
)
{
auto
next
=
std
::
next
(
start
,
msgpack_size_limit
);
f
(
start
,
next
);
start
=
next
;
}
f
(
start
,
last
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
namespace
msgpack
{
namespace
msgpack
{
MSGPACK_API_VERSION_NAMESPACE
(
MSGPACK_DEFAULT_API_NS
)
MSGPACK_API_VERSION_NAMESPACE
(
MSGPACK_DEFAULT_API_NS
)
{
{
...
@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
...
@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
break
;
break
;
}
}
case
msgpack
::
type
::
BIN
:
{
case
msgpack
::
type
::
BIN
:
{
// For backwards compatibility
v
=
migraphx
::
value
::
binary
{
o
.
via
.
bin
.
ptr
,
o
.
via
.
bin
.
size
};
v
=
migraphx
::
value
::
binary
{
o
.
via
.
bin
.
ptr
,
o
.
via
.
bin
.
size
};
break
;
break
;
}
}
case
msgpack
::
type
::
ARRAY
:
{
case
msgpack
::
type
::
ARRAY
:
{
migraphx
::
value
r
=
migraphx
::
value
::
array
{};
if
(
o
.
via
.
array
.
size
!=
0
and
o
.
via
.
array
.
ptr
->
type
==
msgpack
::
type
::
BIN
)
std
::
for_each
(
{
o
.
via
.
array
.
ptr
,
auto
bin
=
migraphx
::
value
::
binary
{};
o
.
via
.
array
.
ptr
+
o
.
via
.
array
.
size
,
std
::
for_each
(
[
&
](
const
msgpack
::
object
&
so
)
{
r
.
push_back
(
so
.
as
<
migraphx
::
value
>
());
});
o
.
via
.
array
.
ptr
,
v
=
r
;
o
.
via
.
array
.
ptr
+
o
.
via
.
array
.
size
,
[
&
](
const
msgpack
::
object
&
so
)
{
bin
.
insert
(
bin
.
end
(),
so
.
via
.
bin
.
ptr
,
so
.
via
.
bin
.
ptr
+
so
.
via
.
bin
.
size
);
});
v
=
bin
;
}
else
{
migraphx
::
value
r
=
migraphx
::
value
::
array
{};
std
::
for_each
(
o
.
via
.
array
.
ptr
,
o
.
via
.
array
.
ptr
+
o
.
via
.
array
.
size
,
[
&
](
const
msgpack
::
object
&
so
)
{
r
.
push_back
(
so
.
as
<
migraphx
::
value
>
());
});
v
=
r
;
}
break
;
break
;
}
}
case
msgpack
::
type
::
MAP
:
{
case
msgpack
::
type
::
MAP
:
{
...
@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
...
@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
{
const
auto
*
data
=
reinterpret_cast
<
const
char
*>
(
x
.
data
());
const
auto
*
data
=
reinterpret_cast
<
const
char
*>
(
x
.
data
());
auto
size
=
x
.
size
();
auto
size
=
x
.
size
();
o
.
pack_bin
(
size
);
o
.
pack_array
(
migraphx
::
msgpack_chunk_size
(
x
));
o
.
pack_bin_body
(
data
,
size
);
migraphx
::
msgpack_chunk_for_each
(
data
,
data
+
size
,
[
&
](
const
char
*
start
,
const
char
*
last
)
{
o
.
pack_bin
(
last
-
start
);
o
.
pack_bin_body
(
start
,
last
-
start
);
});
return
o
;
return
o
;
}
}
};
};
...
@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
...
@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
o
.
pack_array
(
0
);
o
.
pack_array
(
0
);
return
;
return
;
}
}
if
(
v
.
size
()
>
migraphx
::
msgpack_size_limit
)
MIGRAPHX_THROW
(
"Size is too large for msgpack"
);
if
(
not
v
.
front
().
get_key
().
empty
())
if
(
not
v
.
front
().
get_key
().
empty
())
{
{
o
.
pack_map
(
v
.
size
());
o
.
pack_map
(
v
.
size
());
...
...
src/normalize_attributes.cpp
View file @
a045fb19
...
@@ -26,7 +26,7 @@
...
@@ -26,7 +26,7 @@
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -192,20 +192,27 @@ bool normalize_attributes(operation& op, const shape& input_shape)
...
@@ -192,20 +192,27 @@ bool normalize_attributes(operation& op, const shape& input_shape)
auto
val
=
op
.
to_value
();
auto
val
=
op
.
to_value
();
if
(
attrs
.
contains
(
"normalize_padding"
))
if
(
attrs
.
contains
(
"normalize_padding"
))
{
{
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
bool
use_auto_padding
=
auto
padding_size
=
padding
.
size
();
(
val
.
contains
(
"padding_mode"
)
and
auto
padding_start
=
2
;
(
val
.
at
(
"padding_mode"
).
to
<
int
>
()
!=
migraphx
::
op
::
padding_mode_t
::
default_
));
if
(
not
use_auto_padding
)
if
(
padding_size
==
2
*
(
input_shape
.
ndim
()
-
padding_start
))
tuned
=
true
;
else
if
(
padding_size
!=
(
input_shape
.
ndim
()
-
padding_start
))
MIGRAPHX_THROW
(
"inconsistent padding size"
);
else
{
{
auto
result
=
tune_pad_attribute
(
padding
);
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
val
[
"padding"
]
=
result
;
auto
padding_size
=
padding
.
size
();
op
.
from_value
(
val
);
auto
padding_start
=
2
;
tuned
=
true
;
if
(
padding_size
==
2
*
(
input_shape
.
ndim
()
-
padding_start
))
tuned
=
true
;
else
if
(
padding_size
!=
(
input_shape
.
ndim
()
-
padding_start
))
{
MIGRAPHX_THROW
(
"normalize_attributes: inconsistent padding vector size "
);
}
else
{
auto
result
=
tune_pad_attribute
(
padding
);
val
[
"padding"
]
=
result
;
op
.
from_value
(
val
);
tuned
=
true
;
}
}
}
}
}
if
(
not
attrs
.
contains
(
"normalize_axes"
))
if
(
not
attrs
.
contains
(
"normalize_axes"
))
...
...
src/onnx/parse_constant.cpp
View file @
a045fb19
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
...
@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
/*args*/
)
const
const
std
::
vector
<
instruction_ref
>&
/*args*/
)
const
{
{
literal
v
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"value"
));
static
const
std
::
vector
<
std
::
string
>
attributes
=
{
"value"
,
"value_float"
,
"value_floats"
,
"value_int"
,
"value_ints"
};
std
::
vector
<
std
::
string
>
present_attributes
;
std
::
copy_if
(
attributes
.
begin
(),
attributes
.
end
(),
std
::
back_inserter
(
present_attributes
),
[
&
](
const
std
::
string
&
a
)
{
return
contains
(
info
.
attributes
,
a
);
});
if
(
present_attributes
.
empty
())
{
MIGRAPHX_THROW
(
"Constant node does not contain any supported attribute"
);
}
if
(
present_attributes
.
size
()
>
1
)
{
MIGRAPHX_THROW
(
"Constant contains multiple attributes: "
+
join_strings
(
std
::
move
(
present_attributes
),
", "
));
}
// cppcheck-suppress accessMoved
auto
&&
attr
=
info
.
attributes
[
present_attributes
[
0
]];
literal
v
=
parser
.
parse_value
(
attr
);
// return empty literal
// return empty literal
if
(
v
.
get_shape
().
elements
()
==
0
)
if
(
v
.
get_shape
().
elements
()
==
0
)
{
{
return
info
.
add_literal
(
literal
{
v
.
get_shape
().
type
()});
return
info
.
add_literal
(
literal
{
v
.
get_shape
().
type
()});
}
}
auto
dim_size
=
info
.
attributes
.
at
(
"value"
).
t
().
dims_size
();
// if dim_size is 0, it is a scalar
// if dim_size is 0, it is a scalar
if
(
dim_size
==
0
)
if
(
attr
.
has_t
()
and
attr
.
t
().
dim
s
_size
()
==
0
)
{
{
migraphx
::
shape
scalar_shape
{
v
.
get_shape
().
type
()};
migraphx
::
shape
scalar_shape
{
v
.
get_shape
().
type
()};
return
info
.
add_literal
(
migraphx
::
literal
{
scalar_shape
,
v
.
data
()});
return
info
.
add_literal
(
migraphx
::
literal
{
scalar_shape
,
v
.
data
()});
...
...
src/onnx/parse_pooling.cpp
View file @
a045fb19
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -151,26 +151,6 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -151,26 +151,6 @@ struct parse_pooling : op_parser<parse_pooling>
kdims
,
paddings
.
size
()
/
2
,
"PARSE_POOLING: inconsistent explicit paddings"
);
kdims
,
paddings
.
size
()
/
2
,
"PARSE_POOLING: inconsistent explicit paddings"
);
}
}
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
{
if
(
in_shape
.
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_POOLING: Auto padding pooling with dynamic input shape not supported"
);
}
else
{
values
[
"padding"
].
clear
();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size
(
info
,
values
,
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
{
1
,
1
},
in_shape
.
lens
(),
paddings
);
}
}
if
(
paddings
.
size
()
!=
2
*
kdims
)
if
(
paddings
.
size
()
!=
2
*
kdims
)
{
{
paddings
.
resize
(
kdims
*
2
);
paddings
.
resize
(
kdims
*
2
);
...
@@ -192,6 +172,36 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -192,6 +172,36 @@ struct parse_pooling : op_parser<parse_pooling>
// used to calculate the supposed output shape
// used to calculate the supposed output shape
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
// TODO: add parsing for dilations
if
(
contains
(
info
.
attributes
,
"auto_pad"
)
and
to_upper
(
info
.
attributes
[
"auto_pad"
].
s
())
!=
"NOTSET"
)
{
auto
auto_pad
=
to_upper
(
info
.
attributes
[
"auto_pad"
].
s
());
// don't use the given padding sizes, if any
// values["padding"].clear();
if
(
in_shape
.
dynamic
())
{
// set padding_mode to trigger auto padding at runtime
bool
is_same_upper
=
(
auto_pad
.
find
(
"SAME_UPPER"
)
!=
std
::
string
::
npos
);
values
[
"padding_mode"
]
=
is_same_upper
?
to_value
(
op
::
padding_mode_t
::
same_upper
)
:
to_value
(
op
::
padding_mode_t
::
same_lower
);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size
(
info
,
values
,
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
std
::
vector
<
size_t
>
(
in_shape
.
ndim
()
-
2
,
1
),
in_shape
.
lens
(),
paddings
);
values
[
"padding"
]
=
paddings
;
// default padding_mode indicates that padding sizes are not calculated dynamically
values
[
"padding_mode"
]
=
migraphx
::
op
::
padding_mode_t
::
default_
;
}
}
std
::
vector
<
int64_t
>
slice_start
;
std
::
vector
<
int64_t
>
slice_start
;
std
::
vector
<
int64_t
>
slice_end
;
std
::
vector
<
int64_t
>
slice_end
;
tune_padding_size
(
values
,
paddings
,
count_include_pad
,
slice_start
);
tune_padding_size
(
values
,
paddings
,
count_include_pad
,
slice_start
);
...
@@ -208,8 +218,9 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -208,8 +218,9 @@ struct parse_pooling : op_parser<parse_pooling>
orig_padding
.
insert
(
orig_padding
.
begin
(),
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
(),
2
,
0
);
op
::
pad
pad
{
orig_padding
,
0.0
f
};
op
::
pad
pad
{
orig_padding
,
0.0
f
};
shape
padded_shape
=
pad
.
compute_shape
({
l0
->
get_shape
()});
shape
padded_shape
=
pad
.
compute_shape
({
l0
->
get_shape
()});
auto
out_lens
=
make_op
(
"pooling"
,
values
).
compute_shape
({
padded_shape
}).
lens
();
// make an op just to get its output shape
auto
out_lens
=
make_op
(
"pooling"
,
values
).
compute_shape
({
padded_shape
}).
lens
();
// compute slice_end information
// compute slice_end information
slice_end
.
resize
(
slice_start
.
size
());
slice_end
.
resize
(
slice_start
.
size
());
std
::
transform
(
out_lens
.
begin
()
+
2
,
std
::
transform
(
out_lens
.
begin
()
+
2
,
...
...
src/onnx/parse_resize.cpp
View file @
a045fb19
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode)
...
@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode)
static
std
::
vector
<
int
>
static
std
::
vector
<
int
>
calc_neighbor_points
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
size_t
>>>&
vvv_ind
,
calc_neighbor_points
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
size_t
>>>&
vvv_ind
,
int
i_dim
,
int
i_dim
,
const
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
&
vec_dims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_dims
,
const
shape
&
in_s
)
const
shape
&
in_s
)
{
{
if
(
i_dim
==
vvv_ind
.
size
())
if
(
i_dim
==
vvv_ind
.
size
())
{
{
std
::
vector
<
int
>
vec_ind
;
std
::
vector
<
int
>
vec_ind
(
vec_dims
.
size
());
vec_ind
.
resize
(
vec_dims
.
size
());
std
::
transform
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
vec_ind
.
begin
(),
[
&
](
auto
idx
)
{
std
::
transform
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
vec_ind
.
begin
(),
[
&
](
auto
idx
)
{
return
static_cast
<
int
>
(
in_s
.
index
(
idx
));
return
static_cast
<
int
>
(
in_s
.
index
(
idx
));
});
});
return
vec_ind
;
return
vec_ind
;
}
}
const
auto
&
vv_ind
=
vvv_ind
[
i_dim
];
const
auto
&
vv_lo
=
vvv_ind
[
i_dim
][
0
];
const
auto
&
vv_lo
=
vv_ind
.
at
(
0
);
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_dims1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_dims1
;
for
(
std
::
size_t
start
=
0
;
start
<
vec_dims
.
size
();
start
+=
vv_lo
.
size
())
for
(
std
::
size_t
start
=
0
;
start
<
vec_dims
.
size
();
start
+=
vv_lo
.
size
())
{
{
...
@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
...
@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
});
});
}
}
const
auto
&
vv_hi
=
vv_ind
.
at
(
1
)
;
const
auto
&
vv_hi
=
vv
v
_ind
[
i_dim
][
1
]
;
for
(
std
::
size_t
start
=
0
;
start
<
vec_dims
.
size
();
start
+=
vv_
lo
.
size
())
for
(
std
::
size_t
start
=
0
;
start
<
vec_dims
.
size
();
start
+=
vv_
hi
.
size
())
{
{
std
::
transform
(
vv_hi
.
begin
(),
std
::
transform
(
vv_hi
.
begin
(),
vv_hi
.
end
(),
vv_hi
.
end
(),
...
@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
...
@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
return
dim
;
return
dim
;
});
});
}
}
vec_dims
.
clear
();
return
calc_neighbor_points
(
vvv_ind
,
i_dim
+
1
,
vec_dims1
,
in_s
);
return
calc_neighbor_points
(
vvv_ind
,
i_dim
+
1
,
std
::
move
(
vec_dims1
)
,
in_s
);
}
}
static
std
::
string
get_coord_trans_mode
(
const
onnx_parser
::
attribute_map
&
attr
)
static
std
::
string
get_coord_trans_mode
(
const
onnx_parser
::
attribute_map
&
attr
)
...
@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize>
...
@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize>
auto
arg_out_s
=
arg
->
eval
();
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
opd
.
op_name
+
": dynamic output size is not supported!"
);
"PARSE_"
+
opd
.
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
auto
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
arg_out_s
.
visit
([
&
](
const
auto
&
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
{
...
@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize>
...
@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize>
"PARSE_"
+
opd
.
op_name
+
"PARSE_"
+
opd
.
op_name
+
": dynamic input scale is not supported!"
);
": dynamic input scale is not supported!"
);
arg_scale
.
visit
([
&
](
auto
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
arg_scale
.
visit
([
&
](
const
auto
&
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
...
@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize>
...
@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize>
// map out_idx to in_idx
// map out_idx to in_idx
auto
nearest_op
=
get_nearest_op
(
nearest_mode
);
auto
nearest_op
=
get_nearest_op
(
nearest_mode
);
shape_for_each
(
out_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
out_s
,
[
&
](
const
auto
&
out_idx_v
,
size_t
out_
idx
)
{
auto
in_idx
=
idx
;
std
::
vector
<
size_t
>
in_idx
(
out_idx_v
.
size
())
;
for
(
auto
ii
=
0
;
ii
<
in_lens
.
size
();
++
ii
)
for
(
auto
ii
=
0
;
ii
<
in_lens
.
size
();
++
ii
)
{
{
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
idx
[
ii
],
vec_scale
[
ii
]);
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
out_
idx
_v
[
ii
],
vec_scale
[
ii
]);
in_idx
[
ii
]
=
nearest_op
(
in_lens
[
ii
],
idx_val
);
in_idx
[
ii
]
=
nearest_op
(
in_lens
[
ii
],
idx_val
);
}
}
ind
[
out_
s
.
index
(
idx
)
]
=
static_cast
<
int64_t
>
(
in_s
.
index
(
in_idx
));
ind
[
out_idx
]
=
static_cast
<
int64_t
>
(
in_s
.
index
(
in_idx
));
});
});
shape
ind_s
{
shape
::
int32_type
,
out_lens
};
shape
ind_s
{
shape
::
int32_type
,
out_lens
};
...
@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize>
...
@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize>
// get the number of dimensions
// get the number of dimensions
std
::
size_t
n_dim
=
out_lens
.
size
();
std
::
size_t
n_dim
=
out_lens
.
size
();
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vv_ind
(
2
,
std
::
vector
<
std
::
size_t
>
(
out_elements
));
auto
vvv_ind
=
std
::
vector
(
n_dim
,
std
::
vector
(
2
,
std
::
vector
<
size_t
>
(
out_elements
)));
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
size_t
>>>
vvv_ind
(
n_dim
,
vv_ind
);
std
::
vector
<
std
::
vector
<
float
>>
delta
(
n_dim
,
std
::
vector
<
float
>
(
out_elements
));
std
::
vector
<
std
::
vector
<
float
>>
delta
(
n_dim
,
std
::
vector
<
float
>
(
out_elements
));
shape_for_each
(
out_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
out_s
,
[
&
](
const
auto
&
out_idx_v
,
size_t
out_idx
)
{
auto
in_idx
=
idx
;
auto
out_idx
=
out_s
.
index
(
idx
);
for
(
auto
ii
=
0
;
ii
<
in_lens
.
size
();
++
ii
)
for
(
auto
ii
=
0
;
ii
<
in_lens
.
size
();
++
ii
)
{
{
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
idx
[
ii
],
vec_scale
[
ii
]);
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
out_
idx
_v
[
ii
],
vec_scale
[
ii
]);
vvv_ind
[
ii
][
0
][
out_idx
]
=
nearest_floor
(
in_lens
[
ii
],
idx_val
);
vvv_ind
[
ii
][
0
][
out_idx
]
=
nearest_floor
(
in_lens
[
ii
],
idx_val
);
vvv_ind
[
ii
][
1
][
out_idx
]
=
nearest_ceil
(
in_lens
[
ii
],
idx_val
);
vvv_ind
[
ii
][
1
][
out_idx
]
=
nearest_ceil
(
in_lens
[
ii
],
idx_val
);
delta
[
ii
][
out_idx
]
=
idx_val
-
vvv_ind
[
ii
][
0
][
out_idx
];
delta
[
ii
][
out_idx
]
=
idx_val
-
vvv_ind
[
ii
][
0
][
out_idx
];
}
}
});
});
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_dims
(
out_eleme
nts
);
auto
ind
=
calc_neighbor_poi
nts
(
auto
ind
=
calc_neighbor_points
(
vvv_ind
,
0
,
vec_dims
,
in_s
);
vvv_ind
,
0
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
(
out_elements
)
,
in_s
);
auto
ind_lens
=
out_lens
;
auto
ind_lens
=
out_lens
;
ind_lens
[
0
]
*=
(
std
::
size_t
{
1
}
<<
n_dim
);
ind_lens
[
0
]
*=
(
std
::
size_t
{
1
}
<<
n_dim
);
shape
ind_s
{
shape
::
int32_type
,
ind_lens
};
shape
ind_s
{
shape
::
int32_type
,
ind_lens
};
...
...
src/onnx/parse_roialign.cpp
View file @
a045fb19
...
@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
...
@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"RoiAlign"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"RoiAlign"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*
parser
*/
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
{
std
::
string
coord_trans_mode
=
"half_pixel"
;
std
::
string
coord_trans_mode
=
if
(
contains
(
info
.
attributes
,
"coordinate_transformation_mode"
))
parser
.
opset_version
>=
16
?
"half_pixel"
:
"output_half_pixel"
;
if
(
const
auto
*
a
=
"coordinate_transformation_mode"
;
contains
(
info
.
attributes
,
a
))
{
{
coord_trans_mode
=
info
.
attributes
.
at
(
"coordinate_transformation_mode"
).
s
();
coord_trans_mode
=
info
.
attributes
.
at
(
a
).
s
();
}
}
if
(
not
contains
({
"half_pixel"
,
"output_half_pixel"
},
coord_trans_mode
))
if
(
not
contains
({
"half_pixel"
,
"output_half_pixel"
},
coord_trans_mode
))
{
{
MIGRAPHX_THROW
(
"coordinate_transformation_mode
\"
"
+
coord_trans_mode
+
MIGRAPHX_THROW
(
"coordinate_transformation_mode
\"
"
+
coord_trans_mode
+
...
...
src/optimize_module.cpp
View file @
a045fb19
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
...
@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
{
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
{
mpm
.
run_pass
(
simplify_reshapes
{});
// loop to further optimize after initial transformations
mpm
.
run_pass
(
simplify_algebra
{});
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
mpm
.
run_pass
(
simplify_reshapes
{});
mpm
.
run_pass
(
simplify_algebra
{});
}
mpm
.
run_pass
(
eliminate_common_subexpression
{});
mpm
.
run_pass
(
eliminate_common_subexpression
{});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
propagate_constant
{});
mpm
.
run_pass
(
propagate_constant
{});
...
...
src/pad_calc.cpp
View file @
a045fb19
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -52,6 +52,11 @@ void calculate_padding(int64_t idx,
...
@@ -52,6 +52,11 @@ void calculate_padding(int64_t idx,
}
}
}
}
/**
* Given the input array dimensions; kernel (wei_lens); strides; and dilations,
* calculate the padding value in each dimension.
*
*/
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
const
std
::
vector
<
std
::
size_t
>&
wei_lens
,
const
std
::
vector
<
std
::
size_t
>&
wei_lens
,
const
std
::
vector
<
std
::
size_t
>&
strides
,
const
std
::
vector
<
std
::
size_t
>&
strides
,
...
@@ -60,6 +65,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
...
@@ -60,6 +65,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
{
{
std
::
vector
<
std
::
size_t
>
padding
;
std
::
vector
<
std
::
size_t
>
padding
;
assert
(
input_lens
.
size
()
>=
3
);
assert
(
input_lens
.
size
()
>=
3
);
assert
(
input_lens
.
size
()
==
wei_lens
.
size
());
std
::
size_t
num_spatial_dims
=
input_lens
.
size
()
-
2
;
std
::
size_t
num_spatial_dims
=
input_lens
.
size
()
-
2
;
padding
.
resize
(
2
*
num_spatial_dims
);
padding
.
resize
(
2
*
num_spatial_dims
);
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
i
++
)
...
@@ -88,6 +94,11 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
...
@@ -88,6 +94,11 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
return
padding
;
return
padding
;
}
}
/**
* Calculate the correct output shape for a convolution with
* a given input size and other parameters.
*
*/
shape
compute_padded_shape
(
const
shape
&
input
,
shape
compute_padded_shape
(
const
shape
&
input
,
const
shape
&
weights
,
const
shape
&
weights
,
const
std
::
vector
<
std
::
size_t
>&
padding
,
const
std
::
vector
<
std
::
size_t
>&
padding
,
...
@@ -111,5 +122,33 @@ shape compute_padded_shape(const shape& input,
...
@@ -111,5 +122,33 @@ shape compute_padded_shape(const shape& input,
return
input
.
with_lens
(
output_lens
);
return
input
.
with_lens
(
output_lens
);
}
}
/**
* Calculate the correct output shape for a pooling with
* a given input size and other parameters. This uses
* the same formula for pooling that compute_padded_shape() uses
* for convolutions, but takes slightly different inputs.
*
*/
shape
compute_padded_pool_shape
(
const
shape
&
input
,
const
shape
&
kernel
,
const
std
::
vector
<
std
::
size_t
>&
padding
,
const
std
::
vector
<
std
::
size_t
>&
stride
,
const
std
::
vector
<
std
::
size_t
>&
dilation
)
{
const
size_t
num_spatial_dims
=
input
.
lens
().
size
()
-
2
;
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
input
.
lens
()[
1
]};
// calculate the output shape of the pooling: ((W - K + 2P) / S) + 1
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
auto
padding_factor
=
padding
[
i
]
+
padding
[
i
+
num_spatial_dims
];
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
kernel
.
lens
()[
i
]
-
1
))
+
padding_factor
)
/
stride
[
i
]
+
1
)));
}
return
input
.
with_lens
(
output_lens
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/program.cpp
View file @
a045fb19
...
@@ -624,7 +624,7 @@ std::string get_migraphx_version()
...
@@ -624,7 +624,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped
program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file.
if any changes occur to the format of the MXR file.
*/
*/
const
int
program_file_version
=
6
;
const
int
program_file_version
=
7
;
value
program
::
to_value
()
const
value
program
::
to_value
()
const
{
{
...
...
src/propagate_constant.cpp
View file @
a045fb19
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
)
bool
skip_prop
o
gate
(
instruction_ref
ins
)
bool
skip_prop
a
gate
(
instruction_ref
ins
)
{
{
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
skip_prop
o
gate
(
ins
->
inputs
().
front
());
return
skip_prop
a
gate
(
ins
->
inputs
().
front
());
auto
&&
s
=
ins
->
get_shape
();
auto
&&
s
=
ins
->
get_shape
();
if
(
s
.
broadcasted
()
and
not
s
.
scalar
())
if
(
s
.
broadcasted
()
and
not
s
.
scalar
())
return
true
;
return
true
;
...
@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
...
@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return
false
;
return
false
;
}
}
bool
is_const_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_prop
o
gate
(
ins
);
}
bool
is_const_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_prop
a
gate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
void
propagate_constant
::
apply
(
module
&
m
)
const
{
{
...
...
src/shape.cpp
View file @
a045fb19
...
@@ -50,13 +50,14 @@ struct shape_impl
...
@@ -50,13 +50,14 @@ struct shape_impl
{
{
assert
(
t
!=
shape
::
tuple_type
);
assert
(
t
!=
shape
::
tuple_type
);
}
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
{
{
assert
(
t
!=
shape
::
tuple_type
);
assert
(
t
!=
shape
::
tuple_type
);
this
->
calculate_strides
();
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
}
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
{
{
...
@@ -151,6 +152,22 @@ struct shape_impl
...
@@ -151,6 +152,22 @@ struct shape_impl
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
}
std
::
size_t
get_index
(
size_t
i
)
const
{
std
::
size_t
result
=
0
;
std
::
size_t
s
=
1
;
for
(
auto
k
:
migraphx
::
reverse
(
migraphx
::
range
(
m_lens
.
size
())))
{
std
::
size_t
stride
=
m_strides
[
k
];
std
::
size_t
len
=
m_lens
[
k
];
std
::
size_t
idx
=
(
i
%
(
s
*
len
))
/
s
;
result
+=
stride
*
idx
;
s
*=
len
;
}
return
result
;
}
std
::
vector
<
std
::
size_t
>
min_lens
()
const
std
::
vector
<
std
::
size_t
>
min_lens
()
const
{
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
...
@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
...
@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
}
}
MIGRAPHX_THROW
(
"Invalid type"
);
MIGRAPHX_THROW
(
"Invalid type"
);
}
}
std
::
string
shape
::
cpp_type
(
shape
::
type_t
t
)
std
::
string
shape
::
cpp_type
(
shape
::
type_t
t
)
{
{
switch
(
t
)
switch
(
t
)
...
@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
...
@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
shape
::
shape
()
:
impl
(
shape_impl
::
default_shape
())
{}
shape
::
shape
()
:
impl
(
shape_impl
::
default_shape
())
{}
shape
::
shape
(
type_t
t
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
))
{}
shape
::
shape
(
type_t
t
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
))
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
)))
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
)))
{
{
}
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
),
std
::
move
(
s
)))
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
),
std
::
move
(
s
)))
{
{
...
@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
...
@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
standard
())
if
(
this
->
standard
())
return
i
;
return
i
;
else
{
return
impl
->
get_index
(
i
);
std
::
size_t
s
=
1
;
std
::
size_t
result
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
this
->
lens
().
size
();
j
++
)
{
const
std
::
size_t
k
=
this
->
lens
().
size
()
-
j
-
1
;
const
std
::
size_t
stride
=
this
->
strides
()[
k
];
const
std
::
size_t
len
=
this
->
lens
()[
k
];
const
std
::
size_t
idx
=
(
i
%
(
s
*
len
))
/
s
;
result
+=
stride
*
idx
;
s
*=
len
;
}
return
result
;
}
}
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
idx
)
const
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
idx
)
const
...
...
src/simplify_algebra.cpp
View file @
a045fb19
...
@@ -1325,48 +1325,59 @@ struct find_split_reshape
...
@@ -1325,48 +1325,59 @@ struct find_split_reshape
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
slc
=
r
.
instructions
[
"slice"
];
auto
slc
=
r
.
instructions
[
"slice"
];
auto
rsp
=
r
.
instructions
[
"reshape"
];
auto
rsp
=
r
.
instructions
[
"reshape"
];
auto
input
=
slc
->
inputs
().
front
();
// Only apply simplification when slices are on a single axis
auto
axes
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
;
if
(
axes
.
size
()
>
1
)
{
return
;
}
auto
input
=
slc
->
inputs
().
front
();
auto
split_outputs
=
get_splits
(
input
);
auto
split_outputs
=
get_splits
(
input
);
if
(
split_outputs
.
empty
())
if
(
split_outputs
.
empty
())
{
{
return
;
return
;
}
}
// Only want to apply this optimization if each split output is followed by
// Find all the reshapes (similar to rsp) that can be simplified
// a contiguous op and a reshape
std
::
vector
<
instruction_ref
>
conts
;
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
std
::
vector
<
instruction_ref
>
vec_rsp
;
if
(
i
->
outputs
().
size
()
==
1
)
{
// Iterate through slice and contiguous outputs to allow simplifications when
auto
cont
=
i
->
outputs
().
front
();
// slice is followed by multiple reshapes
return
cont
->
outputs
().
size
()
!=
1
;
for
(
auto
&
i
:
split_outputs
)
}
return
false
;
}))
{
{
return
;
std
::
copy_if
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
std
::
back_inserter
(
conts
),
[](
auto
j
)
{
return
j
->
name
()
==
"contiguous"
;
});
}
}
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
for
(
auto
&
i
:
conts
)
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
{
auto
cont
=
i
->
outputs
().
front
();
std
::
copy_if
(
i
->
outputs
().
begin
(),
return
cont
->
outputs
().
front
();
i
->
outputs
().
end
(),
});
std
::
back_inserter
(
vec_rsp
),
[
&
](
auto
j
)
{
return
j
->
get_operator
()
==
rsp
->
get_operator
();
});
}
// all outputs are reshape and of the same shape
// No simplification needed if there is only one slice -> cont -> reshape
auto
dims
=
any_cast
<
op
::
reshape
>
(
rsp
->
get_operator
()).
dims
;
if
(
vec_rsp
.
size
()
<=
1
)
if
(
not
same_ops
(
vec_rsp
))
{
{
return
;
return
;
}
}
// ensure reshape happens after the axis dimension
// ensure reshape happens after the axis dimension
auto
axis
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
[
0
];
auto
axis
=
axes
[
0
];
auto
slc_lens
=
slc
->
get_shape
().
lens
();
auto
slc_lens
=
slc
->
get_shape
().
lens
();
auto
slc_dim_size
=
std
::
accumulate
(
auto
slc_dim_size
=
std
::
accumulate
(
slc_lens
.
begin
()
+
axis
,
slc_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
slc_lens
.
begin
()
+
axis
,
slc_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
input_lens
=
input
->
get_shape
().
lens
();
auto
input_size
=
input
->
get_shape
().
elements
();
auto
slc_axis_len
=
input_lens
[
axis
];
// search the reshape output (standard shape) to decide which axis are
// search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size
// in its output corresponding to the slc_dim_size
...
@@ -1393,16 +1404,67 @@ struct find_split_reshape
...
@@ -1393,16 +1404,67 @@ struct find_split_reshape
{
{
rsp_axis
=
std
::
distance
(
rsp_strides
.
begin
(),
ait
);
rsp_axis
=
std
::
distance
(
rsp_strides
.
begin
(),
ait
);
}
}
// calculate reshape output shape
std
::
vector
<
int64_t
>
vec_dims
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
vec_dims
.
begin
(),
[
&
](
auto
is
)
{
// Calculate reshape output shape
return
is
->
get_shape
().
lens
()[
rsp_axis
];
// Need to find a reshape such that data represented by instructions in vec_rsp can be
});
// written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
rsp_out_lens
[
rsp_axis
]
=
1
;
auto
rsp_fixed_size
=
std
::
accumulate
(
rsp_out_lens
.
begin
(),
rsp_out_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
// cannot create a valid reshape for simplification
if
(
input_size
%
rsp_fixed_size
!=
0
)
{
return
;
}
auto
rsp_axis_len
=
input_size
/
rsp_fixed_size
;
rsp_out_lens
[
rsp_axis
]
=
rsp_axis_len
;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std
::
vector
<
int64_t
>
new_starts
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
new_starts
.
begin
(),
[
&
](
auto
is
)
{
auto
cont
=
is
->
inputs
().
front
();
auto
og_slc
=
cont
->
inputs
().
front
();
return
any_cast
<
op
::
slice
>
(
og_slc
->
get_operator
()).
starts
[
0
]
*
rsp_axis_len
/
slc_axis_len
;
});
std
::
vector
<
int64_t
>
new_ends
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
new_ends
.
begin
(),
[
&
](
auto
is
)
{
auto
cont
=
is
->
inputs
().
front
();
auto
og_slc
=
cont
->
inputs
().
front
();
return
any_cast
<
op
::
slice
>
(
og_slc
->
get_operator
()).
ends
[
0
]
*
rsp_axis_len
/
slc_axis_len
;
});
// insert the reshape instruction and add contiguous if needed
// insert the reshape instruction and add contiguous if needed
if
(
not
input
->
get_shape
().
standard
())
if
(
not
input
->
get_shape
().
standard
())
...
@@ -1413,16 +1475,14 @@ struct find_split_reshape
...
@@ -1413,16 +1475,14 @@ struct find_split_reshape
std
::
next
(
input
),
make_op
(
"reshape"
,
{{
"dims"
,
rsp_out_lens
}}),
input
);
std
::
next
(
input
),
make_op
(
"reshape"
,
{{
"dims"
,
rsp_out_lens
}}),
input
);
// replace the original reshape with slice
// replace the original reshape with slice
int64_t
start
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
vec_rsp
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
vec_rsp
.
size
();
++
i
)
{
{
m
.
replace_instruction
(
m
.
replace_instruction
(
vec_rsp
[
i
],
vec_rsp
[
i
],
make_op
(
make_op
(
"slice"
,
"slice"
,
{{
"axes"
,
{
rsp_axis
}},
{
"starts"
,
{
start
}},
{
"ends"
,
{
start
+
vec_dim
s
[
i
]}}}),
{{
"axes"
,
{
rsp_axis
}},
{
"starts"
,
{
new_
start
s
[
i
]
}},
{
"ends"
,
{
new_end
s
[
i
]}}}),
rsp_ins
);
rsp_ins
);
start
+=
vec_dims
[
i
];
}
}
}
}
};
};
...
@@ -1446,10 +1506,13 @@ struct find_split_transpose
...
@@ -1446,10 +1506,13 @@ struct find_split_transpose
{
{
return
;
return
;
}
}
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
return
i
->
outputs
().
size
()
!=
1
;
}))
return
;
std
::
vector
<
instruction_ref
>
vec_trans
(
split_outputs
.
size
());
std
::
vector
<
instruction_ref
>
vec_trans
(
split_outputs
.
size
());
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_trans
.
begin
(),
[](
auto
i
)
{
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_trans
.
begin
(),
[](
auto
i
)
{
assert
(
i
->
outputs
().
size
()
==
1
);
return
i
->
outputs
().
front
();
return
i
->
outputs
().
front
();
});
});
...
...
src/simplify_dyn_ops.cpp
0 → 100644
View file @
a045fb19
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
 * Rewrites a two-input broadcast/multibroadcast whose inputs both have static
 * shapes into the equivalent one-input form.
 * Some compiler passes (ex. simplify_algebra) only support the 1 input versions
 * of the broadcasting operators.
 */
struct find_static_2in_broadcasts
{
    /// Matches a broadcast-style instruction with exactly two static-shape arguments.
    auto matcher() const
    {
        return match::broadcast(match::nargs(2),
                                match::arg(0)(match::static_shape()),
                                match::arg(1)(match::static_shape()));
    }

    /// Re-creates the matched operator with its output lens fixed to the
    /// instruction's current output shape and keeps only the first input.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto bcast       = mr.result;
        auto target_lens = bcast->get_shape().lens();
        auto op          = bcast->get_operator();
        if(op.name() != "broadcast")
        {
            // Any other matched operator (presumably multibroadcast) also has its
            // `out_dyn_dims` attribute reset to empty.
            op.from_value({{"out_lens", target_lens}, {"out_dyn_dims", {}}});
        }
        else
        {
            op.from_value({{"out_lens", target_lens}});
        }
        m.replace_instruction(bcast, op, bcast->inputs().at(0));
    }
};
/**
 * Replaces a three-input slice (data, starts, ends) with the attribute-only
 * slice when the `starts` and `ends` inputs are constant.
 */
struct find_const_3in_slice
{
    /// Matches a `slice` with three arguments whose 2nd and 3rd arguments are constant.
    auto matcher() const
    {
        return match::name("slice")(match::nargs(3),
                                    match::arg(1)(match::is_constant()),
                                    match::arg(2)(match::is_constant()));
    }

    /// Evaluates the constant inputs and substitutes a one-input slice that
    /// carries them as `starts`/`ends` attributes; `axes` is taken from the
    /// matched operator.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto slice_ins     = mr.result;
        auto operands      = slice_ins->inputs();
        argument starts_in = operands.at(1)->eval();
        argument ends_in   = operands.at(2)->eval();
        // Leave the instruction untouched when either evaluation yields an empty argument.
        if(starts_in.empty() or ends_in.empty())
            return;

        std::vector<int64_t> starts;
        starts_in.visit([&](auto view) { starts.assign(view.begin(), view.end()); });

        std::vector<int64_t> ends;
        ends_in.visit([&](auto view) { ends.assign(view.begin(), view.end()); });

        // The axes are already an attribute of the three-input slice.
        auto op_attrs = slice_ins->get_operator().to_value();
        auto axes     = op_attrs.at("axes").to_vector<int64_t>();

        m.replace_instruction(
            slice_ins,
            make_op("slice", {{"starts", starts}, {"ends", ends}, {"axes", axes}}),
            operands.at(0));
    }
};
/**
 * Replaces a four-input slice (data, starts, ends, axes) with the
 * attribute-only slice when the `starts`, `ends` and `axes` inputs are all
 * constant.
 */
struct find_const_4in_slice
{
    /// Matches a `slice` with four arguments whose 2nd, 3rd and 4th arguments are constant.
    auto matcher() const
    {
        return match::name("slice")(match::nargs(4),
                                    match::arg(1)(match::is_constant()),
                                    match::arg(2)(match::is_constant()),
                                    match::arg(3)(match::is_constant()));
    }

    /// Evaluates the three constant inputs and substitutes a one-input slice
    /// that carries them as `starts`/`ends`/`axes` attributes.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto slice_ins     = mr.result;
        auto operands      = slice_ins->inputs();
        argument starts_in = operands.at(1)->eval();
        argument ends_in   = operands.at(2)->eval();
        argument axes_in   = operands.at(3)->eval();
        // Leave the instruction untouched when any evaluation yields an empty argument.
        if(starts_in.empty() or ends_in.empty() or axes_in.empty())
            return;

        std::vector<int64_t> starts;
        starts_in.visit([&](auto view) { starts.assign(view.begin(), view.end()); });

        std::vector<int64_t> ends;
        ends_in.visit([&](auto view) { ends.assign(view.begin(), view.end()); });

        std::vector<int64_t> axes;
        axes_in.visit([&](auto view) { axes.assign(view.begin(), view.end()); });

        m.replace_instruction(
            slice_ins,
            make_op("slice", {{"starts", starts}, {"ends", ends}, {"axes", axes}}),
            operands.at(0));
    }
};
/**
 * Applies the three rewrites defined in this file to module `m`: the
 * static two-input broadcast conversion and the constant-input slice
 * conversions for three and four inputs.
 */
void simplify_dyn_ops::apply(module& m) const
{
    match::find_matches(
        m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/simplify_reshapes.cpp
View file @
a045fb19
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary
...
@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary
}
}
};
};
/**
 * Folds `multibroadcast -> transpose` into a single multibroadcast when the
 * value being broadcast has a scalar shape.
 */
struct find_broadcast_transpose
{
    /// Matches a `transpose` whose first argument is a `multibroadcast`
    /// (bound as "bcast_ins").
    auto matcher() const
    {
        return match::name("transpose")(
            match::arg(0)(match::name("multibroadcast").bind("bcast_ins")));
    }

    /// Inserts a multibroadcast of the original input straight to the
    /// transpose's output lens and replaces the transpose with it.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto transpose_ins = r.result;
        auto bcast_ins     = r.instructions["bcast_ins"];
        auto source        = bcast_ins->inputs().front();

        // for now, focusing on scalar transformation
        if(not source->get_shape().scalar())
            return;

        auto out_lens = transpose_ins->get_shape().lens();
        auto direct_bcast =
            m.insert_instruction(bcast_ins, make_op("multibroadcast", {{"out_lens", out_lens}}), source);
        m.replace_instruction(transpose_ins, direct_bcast);
    }
};
struct
find_slice_transpose
struct
find_slice_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -784,7 +808,7 @@ struct find_transpose_slice
...
@@ -784,7 +808,7 @@ struct find_transpose_slice
void
simplify_reshapes
::
apply
(
module
&
m
)
const
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
for
(
int
i
=
0
;
i
<
depth
;
i
++
)
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_where_op
{},
find_where_op
{},
...
@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const
...
@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice
{},
find_nested_slice
{},
find_nested_concat
{},
find_nested_concat
{},
find_transpose_slice
{},
find_transpose_slice
{},
find_broadcast_transpose
{},
find_slice_transpose
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{});
find_transpose_contiguous_reshaper_unary
{});
dead_code_elimination
{}.
apply
(
m
);
dead_code_elimination
{}.
apply
(
m
);
...
...
src/split_single_dyn_dim.cpp
View file @
a045fb19
...
@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
...
@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it
->
max
};
dds_it
->
max
};
}
}
namespace
{
struct
find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto
matcher
()
const
{
return
match
::
broadcast
(
match
::
nargs
(
2
),
match
::
arg
(
0
)(
match
::
static_shape
()),
match
::
arg
(
1
)(
match
::
static_shape
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
out_lens
=
ins
->
get_shape
().
lens
();
auto
broadcast_op
=
ins
->
get_operator
();
if
(
broadcast_op
.
name
()
==
"broadcast"
)
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
}});
}
else
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
},
{
"out_dyn_dims"
,
{}}});
}
m
.
replace_instruction
(
ins
,
broadcast_op
,
ins
->
inputs
().
at
(
0
));
}
};
}
// namespace
/**
/**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
* and `loop` instructions, depending on how the submodules for those
...
@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
...
@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check
->
dyn_param_str
,
migraphx
::
shape
{
dyn_param_shape
.
type
(),
static_lens
});
dd_check
->
dyn_param_str
,
migraphx
::
shape
{
dyn_param_shape
.
type
(),
static_lens
});
auto
outputs
=
submod
->
add_instructions
(
mm
,
map_ins
);
auto
outputs
=
submod
->
add_instructions
(
mm
,
map_ins
);
submod
->
add_return
({
outputs
});
submod
->
add_return
({
outputs
});
match
::
find_matches
(
*
submod
,
find_static_2in_broadcasts
{});
submodules
.
push_back
(
submod
);
submodules
.
push_back
(
submod
);
}
}
// redirect to select_module operator and return
// redirect to select_module operator and return
...
...
src/targets/cpu/lowering.cpp
View file @
a045fb19
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
...
src/targets/gpu/CMakeLists.txt
View file @
a045fb19
...
@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
...
@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
add_library
(
migraphx_device
${
DEVICE_GPU_SRCS
}
)
add_library
(
migraphx_device
${
DEVICE_GPU_SRCS
}
)
...
@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device)
...
@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries
(
migraphx_device PUBLIC migraphx
)
target_link_libraries
(
migraphx_device PUBLIC migraphx
)
target_link_libraries
(
migraphx_device PRIVATE compile_for_gpu
)
target_link_libraries
(
migraphx_device PRIVATE compile_for_gpu
)
target_include_directories
(
migraphx_device PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
target_include_directories
(
migraphx_device PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
# Expose the build-tree include directory privately so that the header written
# by configure_file() to include/migraphx/gpu/device/targets.hpp can be found.
# The variable must be CMAKE_CURRENT_BINARY_DIR: the misspelled name is
# undefined and expands to an empty string, giving the bogus path "/include".
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>)
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/include>
)
target_include_directories
(
migraphx_device PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/include>
)
target_compile_options
(
migraphx_device PRIVATE -Wno-ignored-attributes
)
target_compile_options
(
migraphx_device PRIVATE -Wno-ignored-attributes
)
migraphx_generate_export_header
(
migraphx_device DIRECTORY migraphx/gpu/device
)
migraphx_generate_export_header
(
migraphx_device DIRECTORY migraphx/gpu/device
)
...
@@ -123,6 +125,7 @@ add_library(migraphx_gpu
...
@@ -123,6 +125,7 @@ add_library(migraphx_gpu
lrn.cpp
lrn.cpp
mlir.cpp
mlir.cpp
multinomial.cpp
multinomial.cpp
no_device.cpp
nonzero.cpp
nonzero.cpp
pack_args.cpp
pack_args.cpp
pack_int8_args.cpp
pack_int8_args.cpp
...
...
src/targets/gpu/compile_hip.cpp
View file @
a045fb19
...
@@ -115,6 +115,12 @@ struct hiprtc_program
...
@@ -115,6 +115,12 @@ struct hiprtc_program
std
::
string
cpp_src
=
""
;
std
::
string
cpp_src
=
""
;
std
::
string
cpp_name
=
""
;
std
::
string
cpp_name
=
""
;
// Construct a program from a single source string.
// `src` is copied into cpp_src and `name` into cpp_name (`name` defaults to
// "main.cpp"); create_program() is then called to build the program from
// those members.
hiprtc_program
(
const
std
::
string
&
src
,
const
std
::
string
&
name
=
"main.cpp"
)
:
cpp_src
(
src
),
cpp_name
(
name
)
{
create_program
();
}
hiprtc_program
(
std
::
vector
<
hiprtc_src_file
>
srcs
)
hiprtc_program
(
std
::
vector
<
hiprtc_src_file
>
srcs
)
{
{
for
(
auto
&&
src
:
srcs
)
for
(
auto
&&
src
:
srcs
)
...
@@ -130,6 +136,14 @@ struct hiprtc_program
...
@@ -130,6 +136,14 @@ struct hiprtc_program
include_names
.
push_back
(
std
::
move
(
src
.
path
));
include_names
.
push_back
(
std
::
move
(
src
.
path
));
}
}
}
}
create_program
();
}
void
create_program
()
{
assert
(
not
cpp_src
.
empty
());
assert
(
not
cpp_name
.
empty
());
assert
(
headers
.
size
()
==
include_names
.
size
());
prog
=
hiprtc_program_create
(
cpp_src
.
c_str
(),
prog
=
hiprtc_program_create
(
cpp_src
.
c_str
(),
cpp_name
.
c_str
(),
cpp_name
.
c_str
(),
headers
.
size
(),
headers
.
size
(),
...
@@ -137,7 +151,7 @@ struct hiprtc_program
...
@@ -137,7 +151,7 @@ struct hiprtc_program
include_names
.
data
());
include_names
.
data
());
}
}
void
compile
(
const
std
::
vector
<
std
::
string
>&
options
)
const
void
compile
(
const
std
::
vector
<
std
::
string
>&
options
,
bool
quiet
=
false
)
const
{
{
if
(
enabled
(
MIGRAPHX_TRACE_HIPRTC
{}))
if
(
enabled
(
MIGRAPHX_TRACE_HIPRTC
{}))
std
::
cout
<<
"hiprtc "
<<
join_strings
(
options
,
" "
)
<<
" "
<<
cpp_name
<<
std
::
endl
;
std
::
cout
<<
"hiprtc "
<<
join_strings
(
options
,
" "
)
<<
" "
<<
cpp_name
<<
std
::
endl
;
...
@@ -148,7 +162,7 @@ struct hiprtc_program
...
@@ -148,7 +162,7 @@ struct hiprtc_program
[](
const
std
::
string
&
s
)
{
return
s
.
c_str
();
});
[](
const
std
::
string
&
s
)
{
return
s
.
c_str
();
});
auto
result
=
hiprtcCompileProgram
(
prog
.
get
(),
c_options
.
size
(),
c_options
.
data
());
auto
result
=
hiprtcCompileProgram
(
prog
.
get
(),
c_options
.
size
(),
c_options
.
data
());
auto
prog_log
=
log
();
auto
prog_log
=
log
();
if
(
not
prog_log
.
empty
())
if
(
not
prog_log
.
empty
()
and
not
quiet
)
{
{
std
::
cerr
<<
prog_log
<<
std
::
endl
;
std
::
cerr
<<
prog_log
<<
std
::
endl
;
}
}
...
@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
...
@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return
{
prog
.
get_code_obj
()};
return
{
prog
.
get_code_obj
()};
}
}
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
)
{
hiprtc_program
prog
{
" "
};
try
{
prog
.
compile
(
flags
,
true
);
return
true
;
}
catch
(...)
{
return
false
;
}
}
std
::
vector
<
std
::
vector
<
char
>>
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
{
{
...
@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
...
@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return
{
compiler
.
compile
(
srcs
)};
return
{
compiler
.
compile
(
srcs
)};
}
}
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
)
{
src_compiler
compiler
;
compiler
.
compiler
=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
);
compiler
.
flags
=
join_strings
(
flags
,
" "
)
+
" -x hip -c --offload-arch=gfx900 --cuda-device-only"
;
std
::
string
src
;
src_file
input
;
input
.
path
=
"main.cpp"
;
input
.
content
=
std
::
make_pair
(
src
.
data
(),
src
.
data
()
+
src
.
size
());
try
{
compiler
.
compile
({
input
});
return
true
;
}
catch
(...)
{
return
false
;
}
}
#endif // MIGRAPHX_USE_HIPRTC
#endif // MIGRAPHX_USE_HIPRTC
std
::
string
enum_params
(
std
::
size_t
count
,
std
::
string
param
)
std
::
string
enum_params
(
std
::
size_t
count
,
std
::
string
param
)
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
a045fb19
...
@@ -91,28 +91,39 @@ __content__
...
@@ -91,28 +91,39 @@ __content__
return
replace_string
(
args_hpp
,
"__content__"
,
inner
);
return
replace_string
(
args_hpp
,
"__content__"
,
inner
);
}
}
/// Build the list of warning flags passed to the HIP compiler.
///
/// Starts from "-Weverything" plus a fixed set of "-Wno-*" suppressions, then
/// appends "-Wno-unsafe-buffer-usage" only when hip_has_flags() reports that
/// the compiler accepts "-Werror -Wunsafe-buffer-usage".
///
/// @return The warning flags, in the order they should be passed.
static std::vector<std::string> get_compiler_warnings()
{
    std::vector<std::string> result{
        "-Weverything",
        "-Wno-c++98-compat",
        "-Wno-c++98-compat-pedantic",
        "-Wno-conversion",
        "-Wno-double-promotion",
        "-Wno-exit-time-destructors",
        "-Wno-extra-semi",
        "-Wno-extra-semi-stmt",
        "-Wno-float-conversion",
        "-Wno-gnu-anonymous-struct",
        "-Wno-gnu-zero-variadic-macro-arguments",
        "-Wno-missing-prototypes",
        "-Wno-nested-anon-types",
        "-Wno-padded",
        "-Wno-shorten-64-to-32",
        "-Wno-sign-conversion",
        "-Wno-sign-compare",
        "-Wno-unused-command-line-argument",
        "-Wno-weak-vtables",
        "-Wno-c99-extensions",
    };

    // NOTE(review): presumably compilers that do not know -Wunsafe-buffer-usage
    // reject it under -Werror, so the suppression is only added where the
    // warning exists -- confirm.
    const bool knows_unsafe_buffer_usage = hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"});
    if(knows_unsafe_buffer_usage)
    {
        result.push_back("-Wno-unsafe-buffer-usage");
    }
    return result;
}
const
std
::
vector
<
std
::
string
>&
compiler_warnings
()
const
std
::
vector
<
std
::
string
>&
compiler_warnings
()
{
{
static
std
::
vector
<
std
::
string
>
warnings
=
{
"-Weverything"
,
static
std
::
vector
<
std
::
string
>
warnings
=
get_compiler_warnings
();
"-Wno-c++98-compat"
,
"-Wno-c++98-compat-pedantic"
,
"-Wno-conversion"
,
"-Wno-double-promotion"
,
"-Wno-exit-time-destructors"
,
"-Wno-extra-semi"
,
"-Wno-extra-semi-stmt"
,
"-Wno-float-conversion"
,
"-Wno-gnu-anonymous-struct"
,
"-Wno-gnu-zero-variadic-macro-arguments"
,
"-Wno-missing-prototypes"
,
"-Wno-nested-anon-types"
,
"-Wno-padded"
,
"-Wno-shorten-64-to-32"
,
"-Wno-sign-conversion"
,
"-Wno-sign-compare"
,
"-Wno-unused-command-line-argument"
,
"-Wno-weak-vtables"
,
"-Wno-c99-extensions"
};
return
warnings
;
return
warnings
;
}
}
...
...
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment