Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2fc6b715
Commit
2fc6b715
authored
Apr 14, 2023
by
Paul
Browse files
Merge
parents
5967d68d
118e05c7
Changes
177
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
228 additions
and
629 deletions
+228
-629
src/include/migraphx/register_op.hpp
src/include/migraphx/register_op.hpp
+22
-1
src/include/migraphx/register_target.hpp
src/include/migraphx/register_target.hpp
+15
-1
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+2
-1
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+15
-9
src/include/migraphx/split_single_dyn_dim.hpp
src/include/migraphx/split_single_dyn_dim.hpp
+22
-14
src/module.cpp
src/module.cpp
+9
-0
src/msgpack.cpp
src/msgpack.cpp
+16
-0
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+2
-2
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+1
-1
src/onnx/parse_quantizelinear.cpp
src/onnx/parse_quantizelinear.cpp
+6
-8
src/onnx/parse_reshape.cpp
src/onnx/parse_reshape.cpp
+2
-2
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+0
-376
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-193
src/pass_manager.cpp
src/pass_manager.cpp
+9
-0
src/process.cpp
src/process.cpp
+34
-11
src/program.cpp
src/program.cpp
+2
-1
src/promote_literals.cpp
src/promote_literals.cpp
+54
-0
src/propagate_constant.cpp
src/propagate_constant.cpp
+16
-7
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+0
-1
No files found.
src/include/migraphx/register_op.hpp
View file @
2fc6b715
...
@@ -33,15 +33,36 @@
...
@@ -33,15 +33,36 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void
unregister_op
(
const
std
::
string
&
op_name
);
namespace
detail
{
struct
op_handler
{
operation
op
;
std
::
string
name
;
op_handler
(
const
operation
&
op_r
)
:
op
(
op_r
),
name
(
op
.
name
()){};
~
op_handler
()
{
unregister_op
(
name
);
}
};
}
// namespace detail
void
register_op_init
();
void
register_op
(
const
operation
&
op
);
void
register_op
(
const
operation
&
op
);
operation
load_op
(
const
std
::
string
&
name
);
operation
load_op
(
const
std
::
string
&
name
);
bool
has_op
(
const
std
::
string
&
name
);
bool
has_op
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_operators
();
std
::
vector
<
std
::
string
>
get_operators
();
template
<
class
T
>
template
<
class
T
>
void
register_op
()
void
register_op
()
{
{
register_op
(
T
{});
register_op_init
();
// instantiate static op_map;
static
auto
op_h
=
detail
::
op_handler
(
T
{});
register_op
(
op_h
.
op
);
}
}
struct
register_op_action
struct
register_op_action
...
...
src/include/migraphx/register_target.hpp
View file @
2fc6b715
...
@@ -33,14 +33,28 @@
...
@@ -33,14 +33,28 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
register_target_init
();
void
register_target
(
const
target
&
t
);
void
register_target
(
const
target
&
t
);
void
unregister_target
(
const
std
::
string
&
name
);
target
make_target
(
const
std
::
string
&
name
);
target
make_target
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_targets
();
std
::
vector
<
std
::
string
>
get_targets
();
namespace
detail
{
struct
target_handler
{
target
t
;
std
::
string
target_name
;
target_handler
(
const
target
&
t_r
)
:
t
(
t_r
),
target_name
(
t
.
name
())
{}
~
target_handler
()
{
unregister_target
(
target_name
);
}
};
}
// namespace detail
template
<
class
T
>
template
<
class
T
>
void
register_target
()
void
register_target
()
{
{
register_target
(
T
{});
register_target_init
();
static
auto
t_h
=
detail
::
target_handler
(
T
{});
register_target
(
t_h
.
t
);
}
}
struct
register_target_action
struct
register_target_action
...
...
src/include/migraphx/serialize.hpp
View file @
2fc6b715
...
@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
...
@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
}
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
4
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
*
x
.
begin
()),
void
())
auto
from_value_impl
(
rank
<
4
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
*
x
.
begin
()),
std
::
declval
<
typename
T
::
mapped_type
>
(),
void
())
{
{
x
.
clear
();
x
.
clear
();
for
(
auto
&&
e
:
v
)
for
(
auto
&&
e
:
v
)
...
...
src/include/migraphx/shape.hpp
View file @
2fc6b715
...
@@ -29,10 +29,12 @@
...
@@ -29,10 +29,12 @@
#include <ostream>
#include <ostream>
#include <numeric>
#include <numeric>
#include <memory>
#include <memory>
#include <set>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -87,12 +89,12 @@ struct shape
...
@@ -87,12 +89,12 @@ struct shape
{
{
std
::
size_t
min
=
0
;
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
opt
=
0
;
std
::
set
<
std
::
size_t
>
opt
imals
{}
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
,
"opt"
));
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
imals
,
"opt
imals
"
));
}
}
bool
is_fixed
()
const
;
bool
is_fixed
()
const
;
...
@@ -132,11 +134,12 @@ struct shape
...
@@ -132,11 +134,12 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
// Construct a dynamic shape from three sets of lengths (of the same rank)
// Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
shape
(
type_t
t
,
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
opt
s
);
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
opt
imals_list
);
template
<
class
Range
>
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
...
@@ -186,21 +189,21 @@ struct shape
...
@@ -186,21 +189,21 @@ struct shape
/*!
/*!
* Minimum lengths for dynamic shape.
* Minimum lengths for dynamic shape.
* lens() for
fixed
shape.
* lens() for
static
shape.
*/
*/
std
::
vector
<
std
::
size_t
>
min_lens
()
const
;
std
::
vector
<
std
::
size_t
>
min_lens
()
const
;
/*!
/*!
* Maximum lengths for dynamic shape.
* Maximum lengths for dynamic shape.
* lens() for
fixed
shape.
* lens() for
static
shape.
*/
*/
std
::
vector
<
std
::
size_t
>
max_lens
()
const
;
std
::
vector
<
std
::
size_t
>
max_lens
()
const
;
/*!
/*!
* Optimum lengths for dynamic shape.
* Optimum lengths for dynamic shape.
*
lens() for fixed
shape.
*
Empty for static
shape.
*/
*/
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
;
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
opt_lens
()
const
;
/// Map multiple indices to space index
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
...
@@ -253,9 +256,12 @@ struct shape
...
@@ -253,9 +256,12 @@ struct shape
shape
with_type
(
type_t
t
)
const
;
shape
with_type
(
type_t
t
)
const
;
// convert the shape to an equivalent dynamic shape
// convert the shape to an equivalent dynamic shape
with empty optimals
shape
to_dynamic
()
const
;
shape
to_dynamic
()
const
;
// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape
to_static
(
std
::
size_t
x
)
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
...
...
test/context_test.c
pp
→
src/include/migraphx/split_single_dyn_dim.h
pp
View file @
2fc6b715
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -21,20 +21,28 @@
...
@@ -21,20 +21,28 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/serialize.hpp>
#ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <migraphx/context.hpp>
#define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <migraphx/ref/context.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
TEST_CASE
(
context
)
#include <string>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic.
*/
struct
split_single_dyn_dim
{
{
migraphx
::
context
ctx
=
migraphx
::
ref
::
context
{};
std
::
string
name
()
const
{
return
"split_single_dyn_dim"
;
}
migraphx
::
value
v
=
ctx
.
to_value
()
;
void
apply
(
module_pass_manager
&
)
const
;
EXPECT
(
v
.
empty
())
;
}
;
migraphx
::
context
cpu_ctx
=
migraphx
::
ref
::
context
{};
}
// namespace MIGRAPHX_INLINE_NS
cpu_ctx
.
from_value
(
v
);
}
// namespace migraphx
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
#endif
src/module.cpp
View file @
2fc6b715
...
@@ -166,6 +166,7 @@ void module::assign(const module& m)
...
@@ -166,6 +166,7 @@ void module::assign(const module& m)
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
copy_ins
=
impl
->
insert
(
impl
->
instructions
.
end
(),
copy_ins
=
impl
->
insert
(
impl
->
instructions
.
end
(),
{
builtin
::
param
{
name
,
order
},
std
::
move
(
s
),
{}});
{
builtin
::
param
{
name
,
order
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
}
}
else
if
(
ins
->
name
()
==
"@outline"
)
else
if
(
ins
->
name
()
==
"@outline"
)
{
{
...
@@ -594,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
...
@@ -594,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
}
}
}
}
std
::
vector
<
instruction_ref
>
module
::
get_returns
()
const
{
auto
last
=
std
::
prev
(
this
->
end
());
if
(
last
->
name
()
==
"@return"
)
return
last
->
inputs
();
return
{
last
};
}
instruction_ref
module
::
validate
()
const
instruction_ref
module
::
validate
()
const
{
{
return
std
::
find_if
(
return
std
::
find_if
(
...
...
src/msgpack.cpp
View file @
2fc6b715
...
@@ -172,6 +172,22 @@ struct vector_stream
...
@@ -172,6 +172,22 @@ struct vector_stream
}
}
};
};
struct
writer_stream
{
std
::
function
<
void
(
const
char
*
,
std
::
size_t
)
>
writer
;
writer_stream
&
write
(
const
char
*
b
,
std
::
size_t
n
)
{
writer
(
b
,
n
);
return
*
this
;
}
};
void
to_msgpack
(
const
value
&
v
,
std
::
function
<
void
(
const
char
*
,
std
::
size_t
)
>
writer
)
{
writer_stream
ws
{
std
::
move
(
writer
)};
msgpack
::
pack
(
ws
,
v
);
}
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
)
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
)
{
{
vector_stream
vs
;
vector_stream
vs
;
...
...
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
2fc6b715
...
@@ -94,7 +94,7 @@ struct onnx_parser
...
@@ -94,7 +94,7 @@ struct onnx_parser
node_map
nodes
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
program
prog
=
program
();
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
bool
use_dyn_output
=
false
;
bool
use_dyn_output
=
false
;
...
...
src/onnx/onnx.cpp
View file @
2fc6b715
...
@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
...
@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
auto
dim_val
=
options
.
default_dim_value
;
auto
dim_val
=
options
.
default_dim_value
;
if
(
dim_val
!=
0
)
if
(
dim_val
!=
0
)
{
{
if
(
options
.
default_dyn_dim_value
!=
shape
::
dynamic_dimension
{
1
,
1
,
0
})
if
(
options
.
default_dyn_dim_value
!=
shape
::
dynamic_dimension
{
1
,
1
})
{
{
MIGRAPHX_THROW
(
"PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
MIGRAPHX_THROW
(
"PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value"
);
"set to non-default value"
);
}
}
else
else
{
{
parser
.
default_dyn_dim_value
=
{
dim_val
,
dim_val
,
0
};
parser
.
default_dyn_dim_value
=
{
dim_val
,
dim_val
};
}
}
}
}
else
else
...
...
src/onnx/onnx_parser.cpp
View file @
2fc6b715
...
@@ -491,7 +491,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -491,7 +491,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return
default_dyn_dim_value
;
return
default_dyn_dim_value
;
}
}
std
::
size_t
tmp
=
d
.
dim_value
();
std
::
size_t
tmp
=
d
.
dim_value
();
return
{
tmp
,
tmp
,
0
};
return
{
tmp
,
tmp
};
}
}
else
else
{
{
...
...
src/onnx/parse_quantizelinear.cpp
View file @
2fc6b715
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
...
@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
auto
input_lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
input_lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
n_dim
=
input_lens
.
size
();
auto
n_dim
=
input_lens
.
size
();
instruction_ref
y_scale
;
instruction_ref
y_scale
=
args
[
1
]
;
if
(
args
[
1
]
->
get_shape
().
elements
()
!=
1
)
if
(
args
[
1
]
->
get_shape
().
elements
()
!=
1
)
{
{
auto
tuned_axis
=
tune_axis
(
n_dim
,
axis
,
opd
.
op_name
);
auto
tuned_axis
=
tune_axis
(
n_dim
,
axis
,
opd
.
op_name
);
y_scale
=
info
.
add_instruction
(
y_scale
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
tuned_axis
},
{
"out_lens"
,
input_lens
}}),
args
[
1
]);
make_op
(
"broadcast"
,
{{
"axis"
,
tuned_axis
},
{
"out_lens"
,
input_lens
}}),
args
[
1
]);
}
}
else
{
auto
common_args
=
add_common_args
(
*
info
.
mod
,
{
args
[
0
],
y_scale
});
y_scale
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
args
[
1
]);
}
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
...
@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
...
@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
y_zero_point
);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
y_zero_point
);
}
}
return
info
.
add_instruction
(
make_op
(
"quantizelinear"
),
args
[
0
],
y_scale
,
y_zero_point
);
common_args
.
push_back
(
y_zero_point
);
}
}
return
info
.
add_instruction
(
make_op
(
"quantizelinear"
),
args
[
0
],
y_scale
);
return
info
.
add_instruction
(
make_op
(
"quantizelinear"
),
common_args
);
}
}
};
};
...
...
src/onnx/parse_reshape.cpp
View file @
2fc6b715
...
@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
...
@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
}
}
return
info
.
add_instruction
(
make_op
(
"
reshape"
,
{{
"dim
s"
,
dims
}}),
auto
cont
=
info
.
add_instruction
(
make_op
(
"
contiguou
s"
)
,
args
[
0
]);
info
.
make_contiguous
(
args
[
0
])
);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
cont
);
}
}
};
};
...
...
src/opt/memory_coloring_impl.cpp
deleted
100644 → 0
View file @
5967d68d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring_impl
::
run
()
{
// calc implicit depdendencies
mod_implicit_deps
=
p_mod
->
calc_implicit_deps
();
MIGRAPHX_DEBUG
(
dump
(
"---Before memory coloring---"
));
MIGRAPHX_DEBUG
(
dump_module
());
build
();
if
(
num_of_lives
!=
0
)
{
MIGRAPHX_DEBUG
(
dump_intervals
());
// Coloring
while
(
not
alloc_queue
.
empty
())
{
interval_ptr
interval
=
alloc_queue
.
top
();
allocate
(
interval
);
alloc_queue
.
pop
();
}
// rewrite happens after all modules are processed
rewrite
();
if
(
enable_verify
)
verify
();
}
}
bool
memory_coloring_impl
::
allocate
(
interval_ptr
interval
)
{
shape
s
=
interval
->
result
;
std
::
size_t
size
=
s
.
bytes
();
if
(
size
==
0
)
return
false
;
std
::
size_t
element_size
=
(
s
.
elements
()
==
0
?
4
:
(
size
/
s
.
elements
()));
live_range
&
segment
=
interval
->
segment
;
int
vn
=
segment
.
vn
;
std
::
priority_queue
<
live_range
*
,
std
::
vector
<
live_range
*>
,
ordering
>
conflict_queue
;
std
::
unordered_map
<
long
long
,
live_range
*>
offset2_live
;
offset2_live
.
clear
();
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
long
long
offset
=
range
->
offset
;
if
(
offset
!=
invalid_offset
)
{
conflict_queue
.
push
(
range
);
if
(
offset2_live
.
find
(
offset
)
==
offset2_live
.
end
())
{
offset2_live
[
offset
]
=
range
;
}
else
{
live_range
*
prev
=
offset2_live
[
offset
];
assert
(
prev
->
offset
==
offset
);
if
(
prev
->
size
<
range
->
size
)
offset2_live
[
offset
]
=
range
;
}
}
}
}
std
::
size_t
offset
=
0
;
while
(
not
conflict_queue
.
empty
())
{
live_range
*
range
=
conflict_queue
.
top
();
std
::
size_t
iter_offset
=
range
->
offset
;
if
(
offset
>
iter_offset
)
{
offset
=
std
::
max
(
offset
,
iter_offset
+
range
->
size
);
}
else
if
(
offset2_live
[
iter_offset
]
==
range
)
{
if
((
iter_offset
>
offset
)
&&
(
iter_offset
-
offset
)
>=
size
)
{
break
;
}
offset
=
iter_offset
+
range
->
size
;
}
// alignment
if
((
offset
%
element_size
)
!=
0
)
offset
+=
(
element_size
-
(
offset
%
element_size
));
conflict_queue
.
pop
();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset
=
(
offset
+
3
)
/
4
*
4
;
segment
.
offset
=
offset
;
MIGRAPHX_DEBUG
(
segment
.
dump
());
required_bytes
=
std
::
max
(
required_bytes
,
offset
+
segment
.
size
);
return
true
;
}
void
memory_coloring_impl
::
build
()
{
std
::
size_t
num_of_instrs
=
p_mod
->
size
();
if
(
num_of_instrs
==
0
)
return
;
auto
cur_points
=
num_of_instrs
*
2
;
instruction_ref
iter
=
p_mod
->
end
();
instruction_ref
begin
=
p_mod
->
begin
();
std
::
vector
<
instruction_ref
>
dead_instrs
;
std
::
set
<
int
>
live_set
;
// Build live intervals.
live_intervals
.
resize
(
num_of_instrs
);
do
{
iter
=
std
::
prev
(
iter
);
const
instruction
*
p_iter
=
&
(
*
iter
);
interval_ptr
def_interval
=
nullptr
;
bool
is_dead
=
false
;
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
def_interval
=
instr2_live
[
p_iter
];
bool
is_lit
=
is_literal
(
iter
);
if
(
is_allocate
(
iter
)
or
is_lit
)
{
live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
get_shape
();
def_interval
->
is_literal
=
is_lit
;
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
if
(
not
is_lit
or
unify_literals
)
alloc_queue
.
push
(
def_interval
);
live_set
.
erase
(
range
.
vn
);
}
}
else
if
(
not
is_param
(
iter
)
&&
not
is_outline
(
iter
)
&&
not
is_check_context
(
iter
))
{
is_dead
=
true
;
}
auto
inputs
=
iter
->
inputs
();
if
(
contains
(
mod_implicit_deps
,
iter
))
{
const
auto
&
impl_deps
=
mod_implicit_deps
.
at
(
iter
);
inputs
.
insert
(
inputs
.
end
(),
impl_deps
.
begin
(),
impl_deps
.
end
());
}
for
(
auto
&&
arg
:
inputs
)
{
if
(
not
p_mod
->
has_instruction
(
arg
))
continue
;
if
(
is_param
(
arg
)
or
is_outline
(
arg
))
{
if
(
is_output_param
(
arg
))
is_dead
=
false
;
if
(
def_interval
!=
nullptr
)
{
def_interval
->
is_live_on_entry
=
true
;
}
continue
;
}
const
instruction
*
p_arg
=
&
(
*
instruction
::
get_output_alias
(
arg
));
if
(
instr2_live
.
find
(
p_arg
)
==
instr2_live
.
end
())
{
// First time see a use, create a live interval.
int
id
=
num_of_lives
++
;
interval_ptr
interval
=
&
(
live_intervals
[
id
]);
interval
->
id
=
id
;
interval
->
segment
.
end
=
cur_points
;
interval
->
segment
.
vn
=
++
max_value_number
;
interval
->
add_use
(
cur_points
);
instr2_live
[
p_arg
]
=
interval
;
add_conflicts
(
live_set
,
max_value_number
);
live_set
.
insert
(
max_value_number
);
live_ranges
[
max_value_number
]
=
&
(
interval
->
segment
);
earliest_end_point
=
cur_points
;
if
(
latest_end_point
==
-
1
)
latest_end_point
=
cur_points
;
}
else
{
interval_ptr
interval
=
instr2_live
[
p_arg
];
interval
->
add_use
(
cur_points
);
assert
(
live_set
.
find
(
interval
->
id
)
!=
live_set
.
end
());
}
}
if
(
is_dead
)
dead_instrs
.
push_back
(
iter
);
cur_points
-=
2
;
}
while
(
iter
!=
begin
);
}
void
memory_coloring_impl
::
rewrite
()
{
std
::
vector
<
std
::
size_t
>
dims
;
dims
.
push_back
((
required_bytes
+
sizeof
(
float
)
-
1
)
/
sizeof
(
float
));
shape
s
=
{
shape
::
float_type
,
dims
};
instruction_ref
scratch_param
=
p_mod
->
add_parameter
(
"scratch"
,
s
);
for
(
auto
ins
:
iterator_for
(
*
p_mod
))
{
const
instruction
*
p_iter
=
&
(
*
ins
);
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
interval_ptr
interval
=
instr2_live
[
p_iter
];
if
(
interval
->
get_begin
()
==
invalid_offset
)
continue
;
if
(
not
unify_literals
&&
interval
->
is_literal
)
continue
;
std
::
size_t
offset
=
0
;
if
(
interval
->
get_offset
()
!=
invalid_offset
)
{
offset
=
interval
->
get_offset
();
}
else
{
assert
(
interval
->
result
.
bytes
()
==
0
);
}
if
(
is_allocate
(
ins
))
{
p_mod
->
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
offset
}}),
scratch_param
);
}
}
}
MIGRAPHX_DEBUG
(
dump
(
"---After rewrite---"
));
MIGRAPHX_DEBUG
(
dump_module
());
}
void
memory_coloring_impl
::
verify
()
{
if
(
num_of_lives
>
0
)
{
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
const
live_interval
&
interval
=
live_intervals
[
i
];
const
live_range
&
segment
=
interval
.
segment
;
if
(
segment
.
begin
==
invalid_offset
)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue
;
}
if
(
segment
.
offset
==
invalid_offset
)
{
continue
;
}
int
vn
=
segment
.
vn
;
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
if
(
range
->
offset
==
invalid_offset
)
continue
;
if
(
not
is_disjoin
(
*
range
,
segment
))
MIGRAPHX_THROW
(
"range and segment is not disjoined"
);
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void
memory_coloring_impl
::
dump
(
const
std
::
string
&
str
)
{
std
::
cout
<<
str
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_module
()
{
std
::
cout
<<
*
p_mod
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_intervals
()
{
if
(
num_of_lives
>
0
)
{
std
::
cout
<<
"---live intervals ---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
live_interval
&
interval
=
live_intervals
[
i
];
interval
.
dump
();
}
std
::
cout
<<
"---conflict table---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<=
max_value_number
;
++
i
)
{
std
::
cout
<<
" segment:"
<<
i
;
std
::
cout
<<
" =>"
;
const
std
::
set
<
int
>&
table
=
conflict_table
[
i
];
for
(
const
auto
&
iter
:
table
)
{
std
::
cout
<<
(
iter
)
<<
","
;
}
}
std
::
cout
<<
std
::
endl
;
}
}
// map liveness tracking point to instruction enum.
static
int
get_ins_enum
(
int
x
)
{
if
(
x
>
0
)
{
return
(
x
/
2
)
-
1
;
}
else
return
invalid_offset
;
}
void
live_range
::
dump
()
{
std
::
cout
<<
" segment:"
<<
vn
;
std
::
cout
<<
" ["
<<
get_ins_enum
(
begin
)
<<
", "
<<
get_ins_enum
(
end
)
<<
"]"
;
if
(
offset
!=
invalid_offset
)
{
std
::
cout
<<
" mem:"
;
std
::
cout
<<
" ["
<<
offset
<<
","
<<
offset
+
size
-
1
<<
"]"
;
}
std
::
cout
<<
std
::
endl
;
}
void
live_interval
::
dump
()
{
std
::
cout
<<
"id:"
<<
id
;
segment
.
dump
();
std
::
cout
<<
" uses:"
;
for
(
const
auto
&
iter
:
use_points
)
{
std
::
cout
<<
" "
<<
get_ins_enum
(
iter
)
<<
","
;
}
std
::
cout
<<
" def:"
;
std
::
cout
<<
" "
<<
get_ins_enum
(
def_point
);
if
(
is_literal
)
std
::
cout
<<
" literal"
;
std
::
cout
<<
" "
<<
result
;
std
::
cout
<<
std
::
endl
;
}
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/opt/memory_coloring_impl.hpp
deleted
100644 → 0
View file @
5967d68d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
static
const
std
::
size_t
invalid_offset
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
struct
live_range
{
std
::
size_t
begin
;
// begin point in the instruction stream.
std
::
size_t
end
;
// end point in the instruction stream.
std
::
size_t
offset
;
// offset to base pointer of allocated memory trunk.
std
::
size_t
vn
;
// value number that identifies this live_range.
std
::
size_t
size
;
// size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
();
#endif
};
struct
live_interval
{
live_interval
()
:
segment
({
invalid_offset
,
invalid_offset
,
invalid_offset
,
invalid_offset
,
0
})
{
}
void
add_use
(
std
::
size_t
use
)
{
use_points
.
push_front
(
use
);
}
std
::
size_t
get_begin
()
const
{
return
segment
.
begin
;
}
std
::
size_t
get_end
()
const
{
return
segment
.
end
;
}
long
long
get_offset
()
const
{
return
segment
.
offset
;
}
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
();
#endif
live_range
segment
;
std
::
size_t
id
=
invalid_offset
;
std
::
list
<
std
::
size_t
>
use_points
{};
std
::
size_t
def_point
=
invalid_offset
;
shape
result
{};
bool
is_literal
=
false
;
bool
is_live_on_entry
=
false
;
};
using
interval_ptr
=
live_interval
*
;
struct
memory_coloring_impl
{
memory_coloring_impl
(
module
*
p
,
std
::
string
alloc_op
,
bool
p_verify
)
:
p_mod
(
p
),
allocation_op
(
std
::
move
(
alloc_op
)),
enable_verify
(
p_verify
)
{
}
bool
allocate
(
interval_ptr
);
void
add_conflicts
(
const
std
::
set
<
int
>&
live_set
,
int
val
)
{
for
(
const
auto
&
iter
:
live_set
)
{
conflict_table
[
iter
].
insert
(
val
);
conflict_table
[
val
].
insert
(
iter
);
}
}
void
build
();
void
run
();
void
rewrite
();
private:
static
bool
is_param
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@param"
;
}
static
bool
is_output_param
(
const
instruction_ref
ins
)
{
if
(
not
is_param
(
ins
))
return
false
;
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
return
contains
(
param_name
,
"#output_"
);
}
bool
is_allocate
(
const
instruction_ref
ins
)
const
{
return
ins
->
name
()
==
allocation_op
;
}
static
bool
is_outline
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@outline"
;
}
static
bool
is_literal
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@literal"
;
}
static
bool
is_check_context
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"check_context"
;
}
static
bool
is_disjoin
(
const
live_range
&
range1
,
const
live_range
&
range2
)
{
if
((
range1
.
size
==
0
)
or
(
range2
.
size
==
0
))
return
false
;
auto
end1
=
range1
.
offset
+
range1
.
size
-
1
;
auto
end2
=
range2
.
offset
+
range2
.
size
-
1
;
return
((
end1
<
range2
.
offset
)
or
(
end2
<
range1
.
offset
));
}
void
verify
();
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
(
const
std
::
string
&
);
void
dump_module
();
void
dump_intervals
();
#endif
struct
ordering
{
bool
operator
()(
const
interval_ptr
&
i1
,
const
interval_ptr
&
i2
)
const
{
auto
len1
=
i1
->
get_end
()
-
i1
->
get_begin
();
auto
len2
=
i2
->
get_end
()
-
i2
->
get_begin
();
if
(
len1
!=
len2
)
{
return
(
len1
<
len2
);
}
else
if
(
i1
->
result
.
bytes
()
!=
i2
->
result
.
bytes
())
{
return
(
i1
->
result
.
bytes
()
<
i2
->
result
.
bytes
());
}
else
{
return
i1
->
id
>
i2
->
id
;
}
}
bool
operator
()(
const
live_range
*
i1
,
const
live_range
*
i2
)
const
{
return
(
i1
->
offset
>
i2
->
offset
);
}
};
module
*
p_mod
;
std
::
unordered_map
<
const
instruction
*
,
interval_ptr
>
instr2_live
;
// universe of live intervals.
std
::
vector
<
live_interval
>
live_intervals
=
{};
// Map live range value number to live range.
std
::
unordered_map
<
int
,
live_range
*>
live_ranges
=
{};
// Map live range value number to a set of conflicting live ranges' value numbers.
std
::
unordered_map
<
int
,
std
::
set
<
int
>>
conflict_table
=
{};
// Priority queue for coloring.
std
::
priority_queue
<
interval_ptr
,
std
::
vector
<
interval_ptr
>
,
ordering
>
alloc_queue
{};
int
num_of_lives
=
0
;
int
max_value_number
=
-
1
;
std
::
size_t
required_bytes
=
0
;
// The earliest program point where an live interval ends.
int
earliest_end_point
=
-
1
;
// The latest program point where an live interval ends.
int
latest_end_point
=
-
1
;
// Whether to unify literals into coloring.
bool
unify_literals
=
false
;
std
::
string
allocation_op
{};
bool
enable_verify
;
ins_dep_map
mod_implicit_deps
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/pass_manager.cpp
View file @
2fc6b715
...
@@ -86,12 +86,21 @@ struct module_pm : module_pass_manager
...
@@ -86,12 +86,21 @@ struct module_pm : module_pass_manager
assert
(
mod
);
assert
(
mod
);
return
*
mod
;
return
*
mod
;
}
}
virtual
module
*
create_module
(
const
std
::
string
&
name
)
override
virtual
module
*
create_module
(
const
std
::
string
&
name
)
override
{
{
assert
(
prog
);
assert
(
prog
);
return
prog
->
create_module
(
name
);
return
prog
->
create_module
(
name
);
}
}
virtual
module
*
get_common_parent
()
override
{
return
common_parent
;
}
virtual
module
*
get_common_parent
()
override
{
return
common_parent
;
}
virtual
module
*
get_root_module
()
override
{
assert
(
prog
);
return
prog
->
get_main_module
();
}
virtual
void
run_pass
(
const
pass
&
p
)
override
virtual
void
run_pass
(
const
pass
&
p
)
override
{
{
assert
(
mod
);
assert
(
mod
);
...
...
src/process.cpp
View file @
2fc6b715
...
@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os)
...
@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os)
return
[
&
](
const
char
*
x
)
{
os
<<
x
;
};
return
[
&
](
const
char
*
x
)
{
os
<<
x
;
};
}
}
int
exec
(
const
std
::
string
&
cmd
,
const
std
::
function
<
void
(
const
char
*
)
>&
std_out
)
template
<
class
F
>
int
exec
(
const
std
::
string
&
cmd
,
const
char
*
type
,
F
f
)
{
{
int
ec
=
0
;
int
ec
=
0
;
if
(
enabled
(
MIGRAPHX_TRACE_CMD_EXECUTE
{}))
if
(
enabled
(
MIGRAPHX_TRACE_CMD_EXECUTE
{}))
std
::
cout
<<
cmd
<<
std
::
endl
;
std
::
cout
<<
cmd
<<
std
::
endl
;
auto
closer
=
[
&
](
FILE
*
stream
)
{
auto
closer
=
[
&
](
FILE
*
stream
)
{
auto
status
=
pclose
(
stream
);
auto
status
=
pclose
(
stream
);
ec
=
WIFEXITED
(
status
)
?
0
:
WEXITSTATUS
(
status
);
// NOLINT
ec
=
WIFEXITED
(
status
)
?
WEXITSTATUS
(
status
)
:
0
;
// NOLINT
};
};
{
{
// TODO: Use execve instead of popen
// TODO: Use execve instead of popen
std
::
unique_ptr
<
FILE
,
decltype
(
closer
)
>
pipe
(
popen
(
cmd
.
c_str
(),
"r"
),
closer
);
// NOLINT
std
::
unique_ptr
<
FILE
,
decltype
(
closer
)
>
pipe
(
popen
(
cmd
.
c_str
(),
type
),
closer
);
// NOLINT
if
(
not
pipe
)
if
(
not
pipe
)
MIGRAPHX_THROW
(
"popen() failed: "
+
cmd
);
MIGRAPHX_THROW
(
"popen() failed: "
+
cmd
);
std
::
array
<
char
,
128
>
buffer
;
f
(
pipe
.
get
());
while
(
fgets
(
buffer
.
data
(),
buffer
.
size
(),
pipe
.
get
())
!=
nullptr
)
std_out
(
buffer
.
data
());
}
}
return
ec
;
return
ec
;
}
}
int
exec
(
const
std
::
string
&
cmd
,
const
std
::
function
<
void
(
const
char
*
)
>&
std_out
)
{
return
exec
(
cmd
,
"r"
,
[
&
](
FILE
*
f
)
{
std
::
array
<
char
,
128
>
buffer
;
while
(
fgets
(
buffer
.
data
(),
buffer
.
size
(),
f
)
!=
nullptr
)
std_out
(
buffer
.
data
());
});
}
int
exec
(
const
std
::
string
&
cmd
,
std
::
function
<
void
(
process
::
writer
)
>
std_in
)
{
return
exec
(
cmd
,
"w"
,
[
&
](
FILE
*
f
)
{
std_in
([
&
](
const
char
*
buffer
,
std
::
size_t
n
)
{
std
::
fwrite
(
buffer
,
1
,
n
,
f
);
});
});
}
struct
process_impl
struct
process_impl
{
{
std
::
string
command
{};
std
::
string
command
{};
...
@@ -72,6 +87,15 @@ struct process_impl
...
@@ -72,6 +87,15 @@ struct process_impl
result
+=
command
;
result
+=
command
;
return
result
;
return
result
;
}
}
template
<
class
...
Ts
>
void
check_exec
(
Ts
&&
...
xs
)
const
{
int
ec
=
migraphx
::
exec
(
std
::
forward
<
Ts
>
(
xs
)...);
if
(
ec
!=
0
)
MIGRAPHX_THROW
(
"Command "
+
get_command
()
+
" exited with status "
+
std
::
to_string
(
ec
));
}
};
};
process
::
process
(
const
std
::
string
&
cmd
)
:
impl
(
std
::
make_unique
<
process_impl
>
())
process
::
process
(
const
std
::
string
&
cmd
)
:
impl
(
std
::
make_unique
<
process_impl
>
())
...
@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p)
...
@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p)
return
*
this
;
return
*
this
;
}
}
void
process
::
exec
()
void
process
::
exec
()
{
impl
->
check_exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
}
void
process
::
write
(
std
::
function
<
void
(
process
::
writer
)
>
pipe_in
)
{
{
auto
ec
=
migraphx
::
exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
impl
->
check_exec
(
impl
->
get_command
(),
std
::
move
(
pipe_in
));
if
(
ec
!=
0
)
MIGRAPHX_THROW
(
"Command "
+
impl
->
get_command
()
+
" exited with status "
+
std
::
to_string
(
ec
));
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/program.cpp
View file @
2fc6b715
...
@@ -331,7 +331,8 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -331,7 +331,8 @@ std::vector<argument> generic_eval(const module* mod,
MIGRAPHX_THROW
(
"Parameter not found: "
+
param_name
);
MIGRAPHX_THROW
(
"Parameter not found: "
+
param_name
);
auto
param
=
params
[
param_name
];
auto
param
=
params
[
param_name
];
// TODO: may want to check correct number of dimensions and/or was within bounds
// TODO: may want to check correct number of dimensions and/or was within bounds
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
if
(
not
ins
->
get_shape
().
any_of_dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
{
{
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
"} for parameter: "
+
param_name
+
"} for parameter: "
+
param_name
+
...
...
src/promote_literals.cpp
0 → 100644
View file @
2fc6b715
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
promote_literals
::
apply
(
module_pass_manager
&
mpm
)
const
{
module
&
m
=
mpm
.
get_module
();
module_ref
root_module
=
mpm
.
get_root_module
();
if
(
m
.
name
()
==
"main"
)
return
;
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
==
"@literal"
)
{
auto
new_lit
=
root_module
->
add_literal
(
ins
->
get_literal
());
for
(
auto
out_ins
:
ins
->
outputs
())
{
out_ins
->
replace_argument
(
out_ins
,
ins
,
new_lit
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/propagate_constant.cpp
View file @
2fc6b715
...
@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
...
@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
return
false
;
return
false
;
}
}
bool
is_const
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
bool
is_const
_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
void
propagate_constant
::
apply
(
module
&
m
)
const
{
{
...
@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
...
@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal
// Find instructions that can be evaluated to a literal
for
(
auto
i
:
iterator_for
(
m
))
for
(
auto
i
:
iterator_for
(
m
))
{
{
if
(
is_const
(
i
)
and
i
!=
last
)
const
bool
is_const
=
is_const_ins
(
i
);
if
(
is_const
and
i
!=
last
)
continue
;
continue
;
std
::
copy_if
(
if
(
i
==
last
and
is_const
)
i
->
inputs
().
begin
(),
{
const_instrs
.
insert
(
i
);
}
else
{
std
::
copy_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
i
->
inputs
().
end
(),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
[
&
](
const
instruction_ref
ins
)
{
return
is_const
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
[
&
](
const
instruction_ref
ins
)
{
return
is_const_ins
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
}
}
}
// Compute literals in parallel
// Compute literals in parallel
...
...
src/py/migraphx_py.cpp
View file @
2fc6b715
...
@@ -35,7 +35,6 @@
...
@@ -35,7 +35,6 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment