Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
baac1dab
Commit
baac1dab
authored
May 24, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into ck-host-lib
parents
830dff7a
77042e30
Changes
299
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
808 additions
and
714 deletions
+808
-714
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+7
-2
src/onnx/parse_where.cpp
src/onnx/parse_where.cpp
+34
-18
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+0
-376
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-193
src/pass_manager.cpp
src/pass_manager.cpp
+10
-0
src/process.cpp
src/process.cpp
+34
-11
src/program.cpp
src/program.cpp
+4
-5
src/promote_literals.cpp
src/promote_literals.cpp
+54
-0
src/propagate_constant.cpp
src/propagate_constant.cpp
+32
-7
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+71
-17
src/register_op.cpp
src/register_op.cpp
+10
-0
src/register_target.cpp
src/register_target.cpp
+23
-1
src/replace_allocate.cpp
src/replace_allocate.cpp
+5
-8
src/schedule.cpp
src/schedule.cpp
+3
-3
src/shape.cpp
src/shape.cpp
+104
-51
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+250
-17
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-1
src/split_single_dyn_dim.cpp
src/split_single_dyn_dim.cpp
+165
-0
src/targets/cpu/include/migraphx/cpu/parallel.hpp
src/targets/cpu/include/migraphx/cpu/parallel.hpp
+1
-1
src/targets/cpu/include/migraphx/cpu/target.hpp
src/targets/cpu/include/migraphx/cpu/target.hpp
+0
-3
No files found.
src/onnx/parse_slice.cpp
View file @
baac1dab
...
...
@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std
::
vector
<
int64_t
>
steps
;
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice
// to decide whether MIGRAPHX can handle this slice
.
if
(
args
.
size
()
==
5
)
{
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
...
...
@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
}
// If axes arg is not given, the default is all of them.
if
(
op
.
axes
.
empty
())
{
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
op
.
axes
=
axes
;
}
...
...
@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
steps
.
size
()))
{
if
(
steps
[
i
]
>=
0
)
...
...
@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
if
(
not
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
...
...
src/onnx/parse_where.cpp
View file @
baac1dab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
// TODO: broadcasting for dynamic shapes is only implemented
// for binary ops at time of writing, not ternary ops.
// When it becomes available, add multibroadcasting steps in the dynamic shape case.
// For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
}))
{
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
}
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
else
if
(
std
::
none_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
v
)
{
return
v
->
get_shape
().
dynamic
();
})
)
{
args
[
1
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
}
// If shapes are static and any are broadcasted, insert multibroadcast ops
auto
lens
=
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
0
]);
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
2
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
}
if
(
args
[
1
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
1
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
1
]);
}
if
(
args
[
2
]
->
get_shape
().
lens
()
!=
lens
)
{
args
[
2
]
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
args
[
2
]);
}
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
return
info
.
add_instruction
(
make_op
(
"where"
),
args
[
0
],
args
[
1
],
args
[
2
]);
}
else
MIGRAPHX_THROW
(
"PARSE_WHERE: doesn't support mixed static and dynamic shape inputs"
);
}
};
...
...
src/opt/memory_coloring_impl.cpp
deleted
100644 → 0
View file @
830dff7a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Entry point of the pass: gathers the module's implicit dependencies, builds
// the live intervals, gives every queued interval an offset, then rewrites the
// module and (when enable_verify is set) checks the result.
void memory_coloring_impl::run()
{
    // calc implicit dependencies; build() reads them while walking inputs
    mod_implicit_deps = p_mod->calc_implicit_deps();
    MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
    MIGRAPHX_DEBUG(dump_module());
    build();

    // build() found no live interval: nothing to color or rewrite
    if(num_of_lives == 0)
        return;

    MIGRAPHX_DEBUG(dump_intervals());
    // Coloring: drain the queue, placing the top interval before popping it
    for(; not alloc_queue.empty(); alloc_queue.pop())
        allocate(alloc_queue.top());

    // rewrite only after every queued interval has been processed
    rewrite();
    if(enable_verify)
        verify();
}
// Chooses a byte offset for one live interval.
//
// Returns false, leaving the interval untouched, when the interval's shape
// occupies zero bytes. Otherwise the chosen offset is written to the
// interval's segment, required_bytes is raised to cover it, and true is
// returned.
bool memory_coloring_impl::allocate(interval_ptr interval)
{
    shape result_shape     = interval->result;
    const std::size_t size = result_shape.bytes();
    if(size == 0)
        return false;

    // alignment unit: bytes per element, or 4 when the shape reports no elements
    const std::size_t element_size =
        (result_shape.elements() == 0 ? 4 : (size / result_shape.elements()));
    live_range& segment = interval->segment;
    int vn              = segment.vn;

    // conflicting ranges that already have an offset
    std::priority_queue<live_range*, std::vector<live_range*>, ordering> placed;
    // for every occupied offset, the largest conflicting range starting there
    std::unordered_map<long long, live_range*> largest_at;

    const auto conflicts = conflict_table.find(vn);
    if(conflicts != conflict_table.end())
    {
        for(const auto& other_vn : conflicts->second)
        {
            live_range* other      = live_ranges[other_vn];
            long long other_offset = other->offset;
            // ranges without an offset yet cannot constrain this one
            if(other_offset == invalid_offset)
                continue;

            placed.push(other);
            const auto slot = largest_at.find(other_offset);
            if(slot == largest_at.end())
            {
                largest_at[other_offset] = other;
            }
            else
            {
                assert(slot->second->offset == other_offset);
                if(slot->second->size < other->size)
                    slot->second = other;
            }
        }
    }

    std::size_t offset = 0;
    while(not placed.empty())
    {
        live_range* occupied    = placed.top();
        const std::size_t start = occupied->offset;
        if(offset > start)
        {
            // candidate overlaps this range: move past its end if that is further
            offset = std::max(offset, start + occupied->size);
        }
        else if(largest_at[start] == occupied)
        {
            // stop as soon as the gap in front of this range fits the request
            const bool gap_fits = (start > offset) && ((start - offset) >= size);
            if(gap_fits)
                break;
            offset = start + occupied->size;
        }
        // alignment: round up to a multiple of the element size
        const std::size_t misalign = offset % element_size;
        if(misalign != 0)
            offset += (element_size - misalign);
        placed.pop();
    }
    // when int8 type is used, the offset could be any number
    // if not 4-byte aligned, miopen int8 convolution can crash
    offset         = (offset + 3) / 4 * 4;
    segment.offset = offset;
    MIGRAPHX_DEBUG(segment.dump());
    required_bytes = std::max(required_bytes, offset + segment.size);
    return true;
}
void
memory_coloring_impl
::
build
()
{
std
::
size_t
num_of_instrs
=
p_mod
->
size
();
if
(
num_of_instrs
==
0
)
return
;
auto
cur_points
=
num_of_instrs
*
2
;
instruction_ref
iter
=
p_mod
->
end
();
instruction_ref
begin
=
p_mod
->
begin
();
std
::
vector
<
instruction_ref
>
dead_instrs
;
std
::
set
<
int
>
live_set
;
// Build live intervals.
live_intervals
.
resize
(
num_of_instrs
);
do
{
iter
=
std
::
prev
(
iter
);
const
instruction
*
p_iter
=
&
(
*
iter
);
interval_ptr
def_interval
=
nullptr
;
bool
is_dead
=
false
;
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
def_interval
=
instr2_live
[
p_iter
];
bool
is_lit
=
is_literal
(
iter
);
if
(
is_allocate
(
iter
)
or
is_lit
)
{
live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
get_shape
();
def_interval
->
is_literal
=
is_lit
;
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
if
(
not
is_lit
or
unify_literals
)
alloc_queue
.
push
(
def_interval
);
live_set
.
erase
(
range
.
vn
);
}
}
else
if
(
not
is_param
(
iter
)
&&
not
is_outline
(
iter
)
&&
not
is_check_context
(
iter
))
{
is_dead
=
true
;
}
auto
inputs
=
iter
->
inputs
();
if
(
contains
(
mod_implicit_deps
,
iter
))
{
const
auto
&
impl_deps
=
mod_implicit_deps
.
at
(
iter
);
inputs
.
insert
(
inputs
.
end
(),
impl_deps
.
begin
(),
impl_deps
.
end
());
}
for
(
auto
&&
arg
:
inputs
)
{
if
(
not
p_mod
->
has_instruction
(
arg
))
continue
;
if
(
is_param
(
arg
)
or
is_outline
(
arg
))
{
if
(
is_output_param
(
arg
))
is_dead
=
false
;
if
(
def_interval
!=
nullptr
)
{
def_interval
->
is_live_on_entry
=
true
;
}
continue
;
}
const
instruction
*
p_arg
=
&
(
*
instruction
::
get_output_alias
(
arg
));
if
(
instr2_live
.
find
(
p_arg
)
==
instr2_live
.
end
())
{
// First time see a use, create a live interval.
int
id
=
num_of_lives
++
;
interval_ptr
interval
=
&
(
live_intervals
[
id
]);
interval
->
id
=
id
;
interval
->
segment
.
end
=
cur_points
;
interval
->
segment
.
vn
=
++
max_value_number
;
interval
->
add_use
(
cur_points
);
instr2_live
[
p_arg
]
=
interval
;
add_conflicts
(
live_set
,
max_value_number
);
live_set
.
insert
(
max_value_number
);
live_ranges
[
max_value_number
]
=
&
(
interval
->
segment
);
earliest_end_point
=
cur_points
;
if
(
latest_end_point
==
-
1
)
latest_end_point
=
cur_points
;
}
else
{
interval_ptr
interval
=
instr2_live
[
p_arg
];
interval
->
add_use
(
cur_points
);
assert
(
live_set
.
find
(
interval
->
id
)
!=
live_set
.
end
());
}
}
if
(
is_dead
)
dead_instrs
.
push_back
(
iter
);
cur_points
-=
2
;
}
while
(
iter
!=
begin
);
}
void
memory_coloring_impl
::
rewrite
()
{
std
::
vector
<
std
::
size_t
>
dims
;
dims
.
push_back
((
required_bytes
+
sizeof
(
float
)
-
1
)
/
sizeof
(
float
));
shape
s
=
{
shape
::
float_type
,
dims
};
instruction_ref
scratch_param
=
p_mod
->
add_parameter
(
"scratch"
,
s
);
for
(
auto
ins
:
iterator_for
(
*
p_mod
))
{
const
instruction
*
p_iter
=
&
(
*
ins
);
if
(
instr2_live
.
find
(
p_iter
)
!=
instr2_live
.
end
())
{
interval_ptr
interval
=
instr2_live
[
p_iter
];
if
(
interval
->
get_begin
()
==
invalid_offset
)
continue
;
if
(
not
unify_literals
&&
interval
->
is_literal
)
continue
;
std
::
size_t
offset
=
0
;
if
(
interval
->
get_offset
()
!=
invalid_offset
)
{
offset
=
interval
->
get_offset
();
}
else
{
assert
(
interval
->
result
.
bytes
()
==
0
);
}
if
(
is_allocate
(
ins
))
{
p_mod
->
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
offset
}}),
scratch_param
);
}
}
}
MIGRAPHX_DEBUG
(
dump
(
"---After rewrite---"
));
MIGRAPHX_DEBUG
(
dump_module
());
}
void
memory_coloring_impl
::
verify
()
{
if
(
num_of_lives
>
0
)
{
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
const
live_interval
&
interval
=
live_intervals
[
i
];
const
live_range
&
segment
=
interval
.
segment
;
if
(
segment
.
begin
==
invalid_offset
)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue
;
}
if
(
segment
.
offset
==
invalid_offset
)
{
continue
;
}
int
vn
=
segment
.
vn
;
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
const
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
const
auto
&
iter
:
vn_set
)
{
live_range
*
range
=
live_ranges
[
iter
];
if
(
range
->
offset
==
invalid_offset
)
continue
;
if
(
not
is_disjoin
(
*
range
,
segment
))
MIGRAPHX_THROW
(
"range and segment is not disjoined"
);
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
// Debug helper: writes the given text to std::cout, followed by std::endl.
void memory_coloring_impl::dump(const std::string& str)
{
    std::cout << str;
    std::cout << std::endl;
}
// Debug helper: streams the module being colored to std::cout, followed by
// std::endl.
void memory_coloring_impl::dump_module()
{
    std::cout << *p_mod;
    std::cout << std::endl;
}
// Debug helper: prints every live interval and then, on one line, the
// conflict set of each value number from 0 to max_value_number.
// Prints nothing when there are no live intervals.
void memory_coloring_impl::dump_intervals()
{
    if(num_of_lives <= 0)
        return;

    std::cout << "---live intervals ---" << std::endl;
    for(int i = 0; i < num_of_lives; ++i)
        live_intervals[i].dump();

    std::cout << "---conflict table---" << std::endl;
    for(int vn = 0; vn <= max_value_number; ++vn)
    {
        std::cout << " segment:" << vn << " =>";
        // operator[] is used on purpose to keep the lookup identical to the
        // original; it creates an empty set for value numbers with no entry
        const std::set<int>& table = conflict_table[vn];
        for(const auto& other_vn : table)
            std::cout << (other_vn) << ",";
    }
    std::cout << std::endl;
}
// Map a liveness tracking point to an instruction index for debug output.
// Positive points map to (x / 2) - 1; anything else yields invalid_offset
// converted to int.
static int get_ins_enum(int x)
{
    if(x <= 0)
        return invalid_offset;
    return (x / 2) - 1;
}
// Debug helper: prints the value number, the [begin, end] instruction indices
// and, when an offset has been assigned, the inclusive byte range it covers.
void live_range::dump()
{
    std::cout << " segment:" << vn;

    const int first_ins = get_ins_enum(begin);
    const int last_ins  = get_ins_enum(end);
    std::cout << " [" << first_ins << ", " << last_ins << "]";

    if(offset != invalid_offset)
    {
        std::cout << " mem:";
        std::cout << " [" << offset << "," << offset + size - 1 << "]";
    }
    std::cout << std::endl;
}
// Debug helper: prints the interval id, its segment, every use point, the
// definition point, a " literal" tag when applicable, and the result shape.
void live_interval::dump()
{
    std::cout << "id:" << id;
    segment.dump();

    std::cout << " uses:";
    for(const auto& use : use_points)
        std::cout << " " << get_ins_enum(use) << ",";

    std::cout << " def:"
              << " " << get_ins_enum(def_point);
    if(is_literal)
        std::cout << " literal";
    std::cout << " " << result;
    std::cout << std::endl;
}
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/opt/memory_coloring_impl.hpp
deleted
100644 → 0
View file @
830dff7a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Sentinel meaning "not assigned yet": the largest value std::size_t can hold.
static const std::size_t invalid_offset = ~std::size_t{0};
// One live range. Kept as a plain aggregate so it can be brace-initialized
// in member order: begin, end, offset, vn, size.
struct live_range
{
    // Point in the instruction stream where the range begins.
    std::size_t begin;
    // Point in the instruction stream where the range ends.
    std::size_t end;
    // Offset from the base pointer of the allocated memory chunk.
    std::size_t offset;
    // Value number identifying this live_range.
    std::size_t vn;
    // Required memory, in bytes.
    std::size_t size;
#ifdef MIGRAPHX_DEBUG_OPT
    void dump();
#endif
};
// A live interval: one live_range plus the bookkeeping gathered while the
// module is scanned (id, use points, definition point, result shape, flags).
struct live_interval
{
    // Starts with begin, end, offset and vn all invalid and a size of 0.
    live_interval()
        : segment{invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0}
    {
    }

    // Records a use point; the newest one is placed at the front of the list.
    void add_use(std::size_t use) { use_points.push_front(use); }

    std::size_t get_begin() const { return segment.begin; }
    std::size_t get_end() const { return segment.end; }
    // Returned as long long although the segment stores it as std::size_t.
    long long get_offset() const { return segment.offset; }

#ifdef MIGRAPHX_DEBUG_OPT
    void dump();
#endif

    live_range segment;
    std::size_t id = invalid_offset;
    std::list<std::size_t> use_points{};
    std::size_t def_point = invalid_offset;
    shape result{};
    bool is_literal       = false;
    bool is_live_on_entry = false;
};

// Non-owning pointer to a live_interval.
using interval_ptr = live_interval*;
// State and helpers for one memory-coloring run over a single module.
// p is the module to process, alloc_op the name of the allocation operator to
// look for, and p_verify whether run() should call verify() at the end.
struct memory_coloring_impl
{
    memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
        : p_mod(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
    {
    }

    bool allocate(interval_ptr);

    // Records a symmetric conflict between val and every member of live_set.
    void add_conflicts(const std::set<int>& live_set, int val)
    {
        for(const auto& live_vn : live_set)
        {
            conflict_table[live_vn].insert(val);
            conflict_table[val].insert(live_vn);
        }
    }

    void build();
    void run();
    void rewrite();

    private:
    static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; }

    // True for a parameter whose name contains "#output_".
    static bool is_output_param(const instruction_ref ins)
    {
        if(is_param(ins))
        {
            auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
            return contains(param_name, "#output_");
        }
        return false;
    }

    // True when the instruction is named after the configured allocation operator.
    bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }

    static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }

    static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }

    static bool is_check_context(const instruction_ref ins)
    {
        return ins->name() == "check_context";
    }

    // True when the byte ranges [offset, offset + size - 1] of the two ranges
    // do not overlap. A range of size 0 is never reported as disjoint.
    static bool is_disjoin(const live_range& range1, const live_range& range2)
    {
        const bool any_empty = (range1.size == 0) or (range2.size == 0);
        if(any_empty)
            return false;
        const auto last1 = range1.offset + range1.size - 1;
        const auto last2 = range2.offset + range2.size - 1;
        if(last1 < range2.offset)
            return true;
        return last2 < range1.offset;
    }

    void verify();

#ifdef MIGRAPHX_DEBUG_OPT
    void dump(const std::string&);
    void dump_module();
    void dump_intervals();
#endif

    // Comparator shared by both priority queues.
    struct ordering
    {
        // Intervals: compares end - begin first, then result bytes, then id
        // (reversed, so a larger id compares as "less").
        bool operator()(const interval_ptr& i1, const interval_ptr& i2) const
        {
            const auto len1 = i1->get_end() - i1->get_begin();
            const auto len2 = i2->get_end() - i2->get_begin();
            if(len1 != len2)
                return (len1 < len2);

            const auto bytes1 = i1->result.bytes();
            const auto bytes2 = i2->result.bytes();
            if(bytes1 != bytes2)
                return (bytes1 < bytes2);

            return i1->id > i2->id;
        }

        // Ranges: compares by offset with '>'.
        bool operator()(const live_range* i1, const live_range* i2) const
        {
            return (i1->offset > i2->offset);
        }
    };

    module* p_mod;
    // Instruction -> the live interval created for it.
    std::unordered_map<const instruction*, interval_ptr> instr2_live;
    // Universe of live intervals.
    std::vector<live_interval> live_intervals = {};
    // Value number -> live range.
    std::unordered_map<int, live_range*> live_ranges = {};
    // Value number -> value numbers of the conflicting live ranges.
    std::unordered_map<int, std::set<int>> conflict_table = {};
    // Priority queue for coloring.
    std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue{};

    int num_of_lives           = 0;
    int max_value_number       = -1;
    std::size_t required_bytes = 0;
    // The earliest program point where a live interval ends.
    int earliest_end_point = -1;
    // The latest program point where a live interval ends.
    int latest_end_point = -1;
    // Whether to unify literals into coloring.
    bool unify_literals = false;
    std::string allocation_op{};
    bool enable_verify;
    ins_dep_map mod_implicit_deps;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/pass_manager.cpp
View file @
baac1dab
...
...
@@ -86,14 +86,24 @@ struct module_pm : module_pass_manager
assert
(
mod
);
return
*
mod
;
}
// Creates a module with the given name through the owning program.
// A program must be attached (checked with assert).
virtual module* create_module(const std::string& name) override
{
    assert(prog);
    module* created = prog->create_module(name);
    return created;
}
// Returns the stored common_parent module pointer.
virtual module* get_common_parent() override
{
    return common_parent;
}
// Returns the main module of the owning program.
// A program must be attached (checked with assert).
virtual module* get_root_module() override
{
    assert(prog);
    module* root = prog->get_main_module();
    return root;
}
virtual
void
run_pass
(
const
pass
&
p
)
override
{
trace
(
"Pass: "
,
p
.
name
());
assert
(
mod
);
assert
(
mod
->
validate
()
==
mod
->
end
());
if
(
enabled
(
MIGRAPHX_TIME_PASSES
{}))
...
...
src/process.cpp
View file @
baac1dab
...
...
@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os)
return
[
&
](
const
char
*
x
)
{
os
<<
x
;
};
}
int
exec
(
const
std
::
string
&
cmd
,
const
std
::
function
<
void
(
const
char
*
)
>&
std_out
)
template
<
class
F
>
int
exec
(
const
std
::
string
&
cmd
,
const
char
*
type
,
F
f
)
{
int
ec
=
0
;
if
(
enabled
(
MIGRAPHX_TRACE_CMD_EXECUTE
{}))
std
::
cout
<<
cmd
<<
std
::
endl
;
auto
closer
=
[
&
](
FILE
*
stream
)
{
auto
status
=
pclose
(
stream
);
ec
=
WIFEXITED
(
status
)
?
0
:
WEXITSTATUS
(
status
);
// NOLINT
ec
=
WIFEXITED
(
status
)
?
WEXITSTATUS
(
status
)
:
0
;
// NOLINT
};
{
// TODO: Use execve instead of popen
std
::
unique_ptr
<
FILE
,
decltype
(
closer
)
>
pipe
(
popen
(
cmd
.
c_str
(),
"r"
),
closer
);
// NOLINT
std
::
unique_ptr
<
FILE
,
decltype
(
closer
)
>
pipe
(
popen
(
cmd
.
c_str
(),
type
),
closer
);
// NOLINT
if
(
not
pipe
)
MIGRAPHX_THROW
(
"popen() failed: "
+
cmd
);
std
::
array
<
char
,
128
>
buffer
;
while
(
fgets
(
buffer
.
data
(),
buffer
.
size
(),
pipe
.
get
())
!=
nullptr
)
std_out
(
buffer
.
data
());
f
(
pipe
.
get
());
}
return
ec
;
}
// Runs cmd with its pipe opened for reading ("r") and passes everything read
// from that pipe to std_out in chunks read by fgets into a 128-byte buffer.
// Returns whatever the three-argument exec overload returns.
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
    const auto forward_output = [&](FILE* stream) {
        std::array<char, 128> chunk;
        for(;;)
        {
            if(fgets(chunk.data(), chunk.size(), stream) == nullptr)
                break;
            std_out(chunk.data());
        }
    };
    return exec(cmd, "r", forward_output);
}
// Runs cmd with its pipe opened for writing ("w"). std_in receives a writer
// callback; each call writer(buffer, n) writes n bytes from buffer to the
// pipe with std::fwrite. Returns whatever the three-argument exec overload
// returns.
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
{
    const auto feed_input = [&](FILE* stream) {
        const auto write_bytes = [&](const char* buffer, std::size_t n) {
            std::fwrite(buffer, 1, n, stream);
        };
        std_in(write_bytes);
    };
    return exec(cmd, "w", feed_input);
}
struct
process_impl
{
std
::
string
command
{};
...
...
@@ -72,6 +87,15 @@ struct process_impl
result
+=
command
;
return
result
;
}
// Forwards all arguments to migraphx::exec and throws via MIGRAPHX_THROW,
// naming the command and the status, when the returned status is non-zero.
template <class... Ts>
void check_exec(Ts&&... xs) const
{
    const int status = migraphx::exec(std::forward<Ts>(xs)...);
    if(status == 0)
        return;
    MIGRAPHX_THROW("Command " + get_command() + " exited with status " +
                   std::to_string(status));
}
};
process
::
process
(
const
std
::
string
&
cmd
)
:
impl
(
std
::
make_unique
<
process_impl
>
())
...
...
@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p)
return
*
this
;
}
void
process
::
exec
()
void
process
::
exec
()
{
impl
->
check_exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
}
void
process
::
write
(
std
::
function
<
void
(
process
::
writer
)
>
pipe_in
)
{
auto
ec
=
migraphx
::
exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
if
(
ec
!=
0
)
MIGRAPHX_THROW
(
"Command "
+
impl
->
get_command
()
+
" exited with status "
+
std
::
to_string
(
ec
));
impl
->
check_exec
(
impl
->
get_command
(),
std
::
move
(
pipe_in
));
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/program.cpp
View file @
baac1dab
...
...
@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert
(
not
this
->
is_compiled
());
this
->
impl
->
target_name
=
t
.
name
();
this
->
impl
->
ctx
=
t
.
get_context
();
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
(
*
this
);
options
.
trace
();
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
ctx
,
options
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
auto
mods
=
this
->
get_modules
();
// Validate and finalize
for
(
const
auto
&
mod
:
reverse
(
mods
))
{
...
...
@@ -333,7 +331,8 @@ std::vector<argument> generic_eval(const module* mod,
MIGRAPHX_THROW
(
"Parameter not found: "
+
param_name
);
auto
param
=
params
[
param_name
];
// TODO: may want to check correct number of dimensions and/or was within bounds
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
if
(
not
ins
->
get_shape
().
any_of_dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
{
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
"} for parameter: "
+
param_name
+
...
...
@@ -381,7 +380,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
if
(
not
ins
->
get_shape
().
dynamic
())
if
(
not
ins
->
get_shape
().
any_of_
dynamic
())
{
assert
(
results
.
at
(
ins
).
get_shape
()
==
ins
->
get_shape
());
}
...
...
src/promote_literals.cpp
0 → 100644
View file @
baac1dab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Copies every "@literal" instruction of a non-"main" module into the root
// module and redirects the users of the original literal to the new copy.
// The module named "main" is left untouched.
void promote_literals::apply(module_pass_manager& mpm) const
{
    module& m              = mpm.get_module();
    module_ref root_module = mpm.get_root_module();
    if(m.name() == "main")
        return;

    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "@literal")
            continue;

        auto new_lit = root_module->add_literal(ins->get_literal());
        // Iterate over a snapshot of the users. NOTE(review): replace_argument
        // presumably removes out_ins from ins's output list; ranging over
        // ins->outputs() directly would then walk a container that is being
        // edited. Confirm against instruction::replace_argument.
        const auto users = ins->outputs();
        for(auto out_ins : users)
        {
            out_ins->replace_argument(out_ins, ins, new_lit);
        }
    }
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/propagate_constant.cpp
View file @
baac1dab
...
...
@@ -27,11 +27,14 @@
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/env.hpp>
#include <unordered_set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
)
bool
skip_propogate
(
instruction_ref
ins
)
{
if
(
ins
->
name
()
==
"contiguous"
)
...
...
@@ -44,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return
false
;
}
bool
is_const
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
bool
is_const
_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
{
...
...
@@ -54,14 +57,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal
for
(
auto
i
:
iterator_for
(
m
))
{
if
(
is_const
(
i
)
and
i
!=
last
)
const
bool
is_const
=
is_const_ins
(
i
);
if
(
is_const
and
i
!=
last
)
continue
;
std
::
copy_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
[
&
](
const
instruction_ref
ins
)
{
return
is_const
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
if
(
i
==
last
and
is_const
)
{
const_instrs
.
insert
(
i
);
}
else
{
std
::
copy_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
[
&
](
const
instruction_ref
ins
)
{
return
is_const_ins
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
}
}
// Compute literals in parallel
...
...
@@ -76,6 +88,19 @@ void propagate_constant::apply(module& m) const
{
if
(
not
literals
[
i
].
empty
())
{
if
(
enabled
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
{}))
{
std
::
cout
<<
"Constant replace: "
<<
std
::
endl
;
std
::
vector
<
instruction_ref
>
inss
;
fix
([
&
](
auto
self
,
auto
ins
)
{
if
(
contains
(
inss
,
ins
))
return
;
for
(
auto
input
:
ins
->
inputs
())
self
(
input
);
inss
.
push_back
(
ins
);
})(
const_instrs_vec
[
i
]);
m
.
debug_print
(
inss
);
}
assert
(
literals
[
i
].
get_shape
()
==
const_instrs_vec
[
i
]
->
get_shape
());
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instrs_vec
[
i
],
l
);
...
...
src/py/migraphx_py.cpp
View file @
baac1dab
...
...
@@ -35,7 +35,6 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
...
...
@@ -63,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace
migraphx
{
migraphx
::
value
to_value
(
py
::
kwargs
kwargs
);
...
...
@@ -95,6 +95,10 @@ void visit_py(T x, F f)
{
f
(
x
.
template
cast
<
std
::
string
>());
}
else
if
(
py
::
isinstance
<
migraphx
::
shape
::
dynamic_dimension
>
(
x
))
{
f
(
migraphx
::
to_value
(
x
.
template
cast
<
migraphx
::
shape
::
dynamic_dimension
>()));
}
else
{
MIGRAPHX_THROW
(
"VISIT_PY: Unsupported data type!"
);
...
...
@@ -165,7 +169,10 @@ template <class T>
py
::
buffer_info
to_buffer_info
(
T
&
x
)
{
migraphx
::
shape
s
=
x
.
get_shape
();
auto
strides
=
s
.
strides
();
assert
(
s
.
type
()
!=
migraphx
::
shape
::
tuple_type
);
if
(
s
.
dynamic
())
MIGRAPHX_THROW
(
"MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info"
);
auto
strides
=
s
.
strides
();
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
{
return
i
*
s
.
type_size
();
});
py
::
buffer_info
b
;
...
...
@@ -177,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
py
::
format_descriptor
<
bool
>::
format
(),
s
.
lens
().
size
(),
s
.
ndim
(),
s
.
lens
(),
strides
);
}
...
...
@@ -186,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
s
.
lens
().
size
(),
s
.
ndim
(),
s
.
lens
(),
strides
);
}
...
...
@@ -236,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE
(
migraphx
,
m
)
{
py
::
class_
<
migraphx
::
shape
>
(
m
,
"shape"
)
py
::
class_
<
migraphx
::
shape
>
shape_cls
(
m
,
"shape"
);
shape_cls
.
def
(
py
::
init
([](
py
::
kwargs
kwargs
)
{
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
"float"
));
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
"float"
));
if
(
v
.
contains
(
"dyn_dims"
))
{
auto
dyn_dims
=
migraphx
::
from_value
<
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(
v
.
at
(
"dyn_dims"
));
return
migraphx
::
shape
(
t
,
dyn_dims
);
}
auto
lens
=
v
.
get
<
std
::
size_t
>
(
"lens"
,
{
1
});
if
(
v
.
contains
(
"strides"
))
return
migraphx
::
shape
(
t
,
lens
,
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
());
...
...
@@ -249,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"ndim"
,
&
migraphx
::
shape
::
ndim
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"type_string"
,
&
migraphx
::
shape
::
type_string
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"dyn_dims"
,
&
migraphx
::
shape
::
dyn_dims
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
.
def
(
"broadcasted"
,
&
migraphx
::
shape
::
broadcasted
)
.
def
(
"standard"
,
&
migraphx
::
shape
::
standard
)
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
)
.
def
(
"dynamic"
,
&
migraphx
::
shape
::
dynamic
)
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
py
::
enum_
<
migraphx
::
shape
::
type_t
>
(
shape_cls
,
"type_t"
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM
);
py
::
class_
<
migraphx
::
shape
::
dynamic_dimension
>
(
shape_cls
,
"dynamic_dimension"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
std
::
size_t
,
std
::
size_t
>
())
.
def
(
py
::
init
<
std
::
size_t
,
std
::
size_t
,
std
::
set
<
std
::
size_t
>>
())
.
def_readwrite
(
"min"
,
&
migraphx
::
shape
::
dynamic_dimension
::
min
)
.
def_readwrite
(
"max"
,
&
migraphx
::
shape
::
dynamic_dimension
::
max
)
.
def_readwrite
(
"optimals"
,
&
migraphx
::
shape
::
dynamic_dimension
::
optimals
)
.
def
(
"is_fixed"
,
&
migraphx
::
shape
::
dynamic_dimension
::
is_fixed
);
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
py
::
init
([](
py
::
buffer
b
)
{
...
...
@@ -283,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
class_
<
migraphx
::
target
>
(
m
,
"target"
);
py
::
class_
<
migraphx
::
instruction_ref
>
(
m
,
"instruction_ref"
);
py
::
class_
<
migraphx
::
instruction_ref
>
(
m
,
"instruction_ref"
)
.
def
(
"shape"
,
[](
migraphx
::
instruction_ref
i
)
{
return
i
->
get_shape
();
})
.
def
(
"op"
,
[](
migraphx
::
instruction_ref
i
)
{
return
i
->
get_operator
();
});
py
::
class_
<
migraphx
::
module
,
std
::
unique_ptr
<
migraphx
::
module
,
py
::
nodelete
>>
(
m
,
"module"
)
.
def
(
"print"
,
[](
const
migraphx
::
module
&
mm
)
{
std
::
cout
<<
mm
<<
std
::
endl
;
})
...
...
@@ -329,15 +361,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"is_compiled"
,
&
migraphx
::
program
::
is_compiled
)
.
def
(
"compile"
,
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
)
{
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
,
bool
exhaustive_tune
)
{
migraphx
::
compile_options
options
;
options
.
offload_copy
=
offload_copy
;
options
.
fast_math
=
fast_math
;
options
.
offload_copy
=
offload_copy
;
options
.
fast_math
=
fast_math
;
options
.
exhaustive_tune
=
exhaustive_tune
;
p
.
compile
(
t
,
options
);
},
py
::
arg
(
"t"
),
py
::
arg
(
"offload_copy"
)
=
true
,
py
::
arg
(
"fast_math"
)
=
true
)
py
::
arg
(
"offload_copy"
)
=
true
,
py
::
arg
(
"fast_math"
)
=
true
,
py
::
arg
(
"exhaustive_tune"
)
=
false
)
.
def
(
"get_main_module"
,
[](
const
migraphx
::
program
&
p
)
{
return
p
.
get_main_module
();
})
.
def
(
"create_module"
,
...
...
@@ -428,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx"
,
[](
const
std
::
string
&
filename
,
unsigned
int
default_dim_value
,
migraphx
::
shape
::
dynamic_dimension
default_dyn_dim_value
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
print_program_on_error
,
int64_t
max_loop_iterations
)
{
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_dyn_input_dims
=
map_dyn_input_dims
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
max_loop_iterations
=
max_loop_iterations
;
...
...
@@ -442,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Parse onnx file"
,
py
::
arg
(
"filename"
),
py
::
arg
(
"default_dim_value"
)
=
1
,
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"default_dim_value"
)
=
0
,
py
::
arg
(
"default_dyn_dim_value"
)
=
migraphx
::
shape
::
dynamic_dimension
{
1
,
1
},
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_dyn_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"max_loop_iterations"
)
=
10
);
...
...
@@ -452,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx_buffer"
,
[](
const
std
::
string
&
onnx_buffer
,
unsigned
int
default_dim_value
,
migraphx
::
shape
::
dynamic_dimension
default_dyn_dim_value
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
print_program_on_error
)
{
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_dyn_input_dims
=
map_dyn_input_dims
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
return
migraphx
::
parse_onnx_buffer
(
onnx_buffer
,
options
);
},
"Parse onnx file"
,
py
::
arg
(
"filename"
),
py
::
arg
(
"default_dim_value"
)
=
1
,
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"default_dim_value"
)
=
0
,
py
::
arg
(
"default_dyn_dim_value"
)
=
migraphx
::
shape
::
dynamic_dimension
{
1
,
1
},
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_dyn_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
);
...
...
src/register_op.cpp
View file @
baac1dab
...
...
@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
static
std
::
unordered_map
<
std
::
string
,
operation
>
m
;
// NOLINT
return
m
;
}
// Calls op_map() and discards the result, so that the function-local static
// registry inside it gets constructed.
void register_op_init() { static_cast<void>(op_map()); }
// Stores `op` in the operator registry under op.name(); an existing entry
// with the same name is overwritten.
void register_op(const operation& op)
{
    auto& registry      = op_map();
    registry[op.name()] = op;
}
// Removes the operator registered as `op_name` from the registry.
// The name is expected to be present (checked by assert only).
void unregister_op(const std::string& op_name)
{
    auto& registry = op_map();
    assert(registry.count(op_name));
    registry.erase(op_name);
}
operation
load_op
(
const
std
::
string
&
name
)
{
return
at
(
op_map
(),
name
,
"Operator not found: "
+
name
);
...
...
src/register_target.cpp
View file @
baac1dab
...
...
@@ -21,26 +21,48 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <string>
#include <unordered_map>
#include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dynamic_loader.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
store_target_lib
(
const
dynamic_loader
&
lib
)
{
static
std
::
vector
<
dynamic_loader
>
target_loader
;
target_loader
.
emplace_back
(
lib
);
}
std
::
unordered_map
<
std
::
string
,
target
>&
target_map
()
{
static
std
::
unordered_map
<
std
::
string
,
target
>
m
;
// NOLINT
return
m
;
}
// Calls target_map() and discards the result, so that the function-local
// static registry inside it gets constructed.
void register_target_init() { static_cast<void>(target_map()); }
// Removes the target registered as `name` from the registry.
// The name is expected to be present (checked by assert only).
void unregister_target(const std::string& name)
{
    auto& registry = target_map();
    assert(registry.count(name));
    registry.erase(name);
}
// Stores `t` in the target registry under t.name(); an existing entry with
// the same name is overwritten.
void register_target(const target& t)
{
    auto& registry     = target_map();
    registry[t.name()] = t;
}
// Looks up the target registered as `name`.
//
// If the name is not registered yet, the shared library
// "libmigraphx_<name>.so" is opened and its loader is retained through
// store_target_lib(); the registry is then consulted again.
// NOTE(review): presumably loading the library registers the target as a
// side effect -- confirm against the target libraries.
//
// Throws (via MIGRAPHX_THROW) when the target is still not registered.
target make_target(const std::string& name)
{
    if(not contains(target_map(), name))
    {
        const std::string target_name = "libmigraphx_" + name + ".so";
        store_target_lib(dynamic_loader(target_name));
    }
    const auto it = target_map().find(name);
    if(it == target_map().end())
    {
        MIGRAPHX_THROW("Requested target '" + name + "' is not loaded or not supported");
    }
    return it->second;
}
...
...
src/replace_allocate.cpp
View file @
baac1dab
...
...
@@ -104,19 +104,16 @@ void replace_allocate::apply(module& m) const
continue
;
auto
s
=
ins
->
get_shape
();
if
(
not
main_offload_copy
and
model
.
needs_out_params
()
and
contains
(
mod_output_names
,
ins
))
{
auto
out_param
=
m
.
add_parameter
(
mod_output_names
[
ins
],
s
);
m
.
replace_instruction
(
ins
,
out_param
);
continue
;
}
m
.
replace_instruction
(
ins
,
m
.
insert_instruction
(
ins
,
make_op
(
model
.
name
(),
migraphx
::
value
{{
"shape"
,
to_value
(
s
)}})));
else
{
m
.
replace_instruction
(
ins
,
make_op
(
model
.
name
(),
migraphx
::
value
{{
"shape"
,
to_value
(
s
)}}));
}
}
}
...
...
src/schedule.cpp
View file @
baac1dab
...
...
@@ -327,10 +327,10 @@ struct stream_info
return
[
=
](
auto
f
)
{
return
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
)
{
return
all_of
(
select
(
ins
),
[
&
](
auto
i
)
{
if
(
iweights
.
at
(
i
)
==
0
)
return
self
(
i
);
else
if
(
has_stream
(
i
))
return
f
(
this
->
get_stream
(
i
));
else
return
self
(
i
);
});
})(
start
);
};
...
...
src/shape.cpp
View file @
baac1dab
...
...
@@ -74,13 +74,23 @@ struct shape_impl
// Builds a dynamic shape from parallel per-dimension vectors.
//
// mins / maxes    : lower and upper bound of every dimension; must have the
//                   same length because they are indexed in lock-step.
// optimals_list   : either empty (no dimension gets optimals) or one set of
//                   optimal sizes per dimension.
shape_impl(shape::type_t t,
           std::vector<std::size_t> mins,
           std::vector<std::size_t> maxes,
           std::vector<std::set<std::size_t>> optimals_list)
    : m_type(t)
{
    // Checked on every path: maxes[i] is read for each i < mins.size()
    // regardless of whether optimals were supplied.
    assert(mins.size() == maxes.size());
    assert(optimals_list.empty() or optimals_list.size() == mins.size());

    const bool with_optimals = not optimals_list.empty();
    m_dyn_dims.reserve(mins.size());
    for(std::size_t i = 0; i < mins.size(); ++i)
    {
        if(with_optimals)
        {
            m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
        }
        else
        {
            m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
        }
    }
}
...
...
@@ -147,7 +157,7 @@ struct shape_impl
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
min
;
});
[](
const
shape
::
dynamic_dimension
&
x
)
{
return
x
.
min
;
});
return
ret
;
}
...
...
@@ -157,19 +167,20 @@ struct shape_impl
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
max
;
});
[](
const
shape
::
dynamic_dimension
&
x
)
{
return
x
.
max
;
});
return
ret
;
}
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
opt_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
opt
;
});
[](
const
shape
::
dynamic_dimension
&
x
)
{
return
x
.
opt
imals
;
});
return
ret
;
}
// Does the shape skip over elements?
bool
skips
()
const
{
...
...
@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
// Dynamic-shape constructor taking parallel vectors of minimums, maximums
// and per-dimension optimal sets; all three are moved into the shared
// implementation object.
shape::shape(type_t t,
             std::vector<std::size_t> mins,
             std::vector<std::size_t> maxes,
             std::vector<std::set<std::size_t>> optimals_list)
    : impl(std::make_shared<shape_impl>(
          t, std::move(mins), std::move(maxes), std::move(optimals_list)))
{
}
...
...
@@ -349,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
)
const
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
dx
)
const
{
assert
(
this
->
standard
());
assert
(
idx
<
elements
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
multi_copy
(
i
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
multi_copy
(
idx
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
return
indices
;
}
void
shape
::
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
void
shape
::
multi_copy
(
std
::
size_t
i
dx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
{
assert
(
this
->
standard
())
;
size_t
tidx
=
idx
;
(
void
)
end
;
assert
(
idx
<
elements
());
assert
(
lens
().
size
()
<=
(
end
-
start
));
std
::
transform
(
strides
().
begin
(),
strides
().
end
(),
lens
().
begin
(),
start
,
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
assert
(
len
>
0
and
stride
>
0
);
return
(
i
/
stride
)
%
len
;
});
for
(
size_t
ii
=
lens
().
size
()
-
1
;
ii
>
0
;
ii
--
)
{
*
(
start
+
ii
)
=
tidx
%
lens
()[
ii
];
tidx
=
tidx
/
lens
()[
ii
];
}
*
start
=
tidx
;
}
bool
shape
::
packed
()
const
...
...
@@ -469,12 +478,44 @@ shape shape::with_type(type_t t) const
// Returns a dynamic version of this shape.
//  - A shape with sub-shapes is converted element by element.
//  - An already dynamic shape is returned unchanged.
//  - Otherwise every dimension becomes a range with min == max == the
//    static length and no optimals.
shape shape::to_dynamic() const
{
    if(not sub_shapes().empty())
    {
        std::vector<shape> subs;
        for(auto s : sub_shapes())
        {
            subs.push_back(s.to_dynamic());
        }
        return {subs};
    }
    if(this->dynamic())
    {
        return *this;
    }
    return {type(), lens(), lens(), {}};
}
// Returns a static version of this shape.
//  - A shape with sub-shapes is converted element by element.
//  - An already static shape is returned unchanged.
//  - Otherwise every fixed dimension keeps its (maximum) length and every
//    non-fixed dimension is set to `x`.
shape shape::to_static(std::size_t x) const
{
    if(not sub_shapes().empty())
    {
        std::vector<shape> subs;
        std::transform(sub_shapes().cbegin(),
                       sub_shapes().cend(),
                       std::back_inserter(subs),
                       [&](const auto& s) { return s.to_static(x); });
        return {subs};
    }
    if(not this->dynamic())
    {
        return *this;
    }
    auto static_lens = this->max_lens();
    // The dimension is taken by const reference: copying it would also copy
    // the set of optimals it owns.
    std::transform(static_lens.begin(),
                   static_lens.end(),
                   this->dyn_dims().cbegin(),
                   static_lens.begin(),
                   [&](std::size_t sl, const auto& dd) { return dd.is_fixed() ? sl : x; });
    return {type(), static_lens};
}
// Forwards to the implementation object's element_space().
std::size_t shape::element_space() const
{
    const std::size_t space = impl->element_space();
    return space;
}
...
...
@@ -483,6 +524,17 @@ std::string shape::type_string() const { return name(this->type()); }
// A shape is dynamic when it stores at least one dynamic dimension.
bool shape::dynamic() const
{
    const bool no_dyn_dims = impl->m_dyn_dims.empty();
    return not no_dyn_dims;
}
// True when this shape is dynamic, or when any of its sub-shapes
// (recursively) is.
bool shape::any_of_dynamic() const
{
    if(this->dynamic())
    {
        return true;
    }
    const auto& subs = this->sub_shapes();
    for(auto it = subs.cbegin(); it != subs.cend(); ++it)
    {
        auto sub = *it;
        if(sub.any_of_dynamic())
            return true;
    }
    return false;
}
// Read-only access to the dynamic dimensions stored in the implementation
// object.
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
{
    const auto& dims = impl->m_dyn_dims;
    return dims;
}
std
::
vector
<
std
::
size_t
>
shape
::
min_lens
()
const
...
...
@@ -495,23 +547,22 @@ std::vector<std::size_t> shape::max_lens() const
return
this
->
dynamic
()
?
impl
->
max_lens
()
:
this
->
lens
();
}
// Per-dimension sets of optimal sizes, as reported by the implementation
// object.
std::vector<std::set<std::size_t>> shape::opt_lens() const
{
    auto optimals_per_dim = impl->opt_lens();
    return optimals_per_dim;
}
// A dimension is fixed when its minimum and maximum coincide.
bool shape::dynamic_dimension::is_fixed() const
{
    const bool same_bounds = (this->min == this->max);
    return same_bounds;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
not
optimals
.
empty
()
;
}
// Shifts the whole dimension upwards by `x`: min, max and every optimal
// size are increased by the same amount.
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{
    this->min += x;
    this->max += x;

    // A set's elements cannot be modified in place, so build the shifted set
    // separately and then replace the old one.
    std::set<std::size_t> shifted;
    for(const auto& opt : this->optimals)
    {
        shifted.insert(shifted.end(), opt + x);
    }
    this->optimals = shifted;
    return *this;
}
...
...
@@ -521,19 +572,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t
assert
(
this
->
max
>=
x
);
this
->
min
-=
x
;
this
->
max
-=
x
;
if
(
this
->
opt
!=
0
)
{
assert
(
this
->
opt
>=
x
);
this
->
opt
-=
x
;
}
std
::
set
<
std
::
size_t
>
new_optimals
;
std
::
transform
(
this
->
optimals
.
begin
(),
this
->
optimals
.
end
(),
std
::
inserter
(
new_optimals
,
new_optimals
.
begin
()),
[
&
x
](
const
auto
&
opt
)
{
assert
(
opt
>=
x
);
return
(
opt
-
x
);
});
this
->
optimals
=
new_optimals
;
return
*
this
;
}
// Two dynamic dimensions are equal when their bounds match and either both
// are fixed or their optimal sets match.
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
    if(not(x.min == y.min))
        return false;
    if(not(x.max == y.max))
        return false;
    // optimals are not compared when both dimensions are fixed
    if(x.is_fixed() and y.is_fixed())
        return true;
    return x.optimals == y.optimals;
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
...
...
@@ -542,7 +597,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
{
os
<<
"["
<<
x
.
min
<<
", "
<<
x
.
max
<<
", "
<<
x
.
opt
<<
"]"
;
os
<<
"[
"
<<
x
.
min
<<
", "
<<
x
.
max
<<
",
{
"
<<
migraphx
::
to_string_range
(
x
.
optimals
)
<<
"
}
]"
;
return
os
;
}
...
...
@@ -651,12 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
{
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
migraphx
::
value
x
)
{
auto
x_min
=
x
.
at
(
"min"
).
template
to
<
size_t
>();
auto
x_max
=
x
.
at
(
"max"
).
template
to
<
size_t
>();
auto
x_opt
=
x
.
at
(
"opt"
).
template
to
<
size_t
>();
return
shape
::
dynamic_dimension
{
x_min
,
x_max
,
x_opt
};
});
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
const
migraphx
::
value
&
x
)
{
return
from_value
<
shape
::
dynamic_dimension
>
(
x
);
});
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
}
...
...
src/simplify_algebra.cpp
View file @
baac1dab
...
...
@@ -204,7 +204,137 @@ struct find_mul_slice_conv
}
};
// a * (x + b) => a * x + a * b
// Rewrites mul(dot(a, b), c), where `c` is a broadcast/multibroadcast and the
// dot has at least one constant input, into a single dot in which one of the
// evaluable operands has been pre-multiplied by a re-broadcast `c`.
struct find_mul_dot
{
    auto matcher() const
    {
        auto has_const_input = match::any_of[match::inputs()](match::is_constant());
        auto dot_matcher     = match::name("dot")(has_const_input).bind("dot");
        auto bcast_matcher   = match::name("broadcast", "multibroadcast").bind("c");
        return match::name("mul")(match::either_arg(0, 1)(dot_matcher, bcast_matcher));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins     = r.result;
        auto dot_ins = r.instructions["dot"];
        auto lhs     = dot_ins->inputs()[0];
        auto rhs     = dot_ins->inputs()[1];
        auto c_ins   = r.instructions["c"];

        const auto& c_strides = c_ins->get_shape().strides();

        // Bail out unless at most one stride of the broadcast is non-zero
        const auto nonzero_strides =
            std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; });
        if(nonzero_strides > 1)
            return;

        // Multiplies `operand` by `c` re-broadcast to the operand's lens.
        // Yields m.end() when the operand cannot be evaluated.
        auto scale_operand = [&](instruction_ref operand) {
            if(not operand->can_eval())
                return m.end();
            auto bcast_attrs        = c_ins->get_operator().to_value();
            bcast_attrs["out_lens"] = operand->get_shape().lens();

            auto rebroadcast =
                m.insert_instruction(ins, make_op(c_ins->name(), bcast_attrs), c_ins->inputs());
            return m.insert_instruction(ins, make_op("mul"), operand, rebroadcast);
        };

        if(c_strides.back() == 1)
        {
            // the broadcast varies along the last axis
            rhs = scale_operand(rhs);
        }
        else if(c_strides[c_strides.size() - 2] == 1)
        {
            // the broadcast varies along the second to last axis
            lhs = scale_operand(lhs);
        }
        else if(c_ins->get_shape().scalar())
        {
            // a scalar can be folded into whichever operand is evaluable
            if(lhs->can_eval())
                lhs = scale_operand(lhs);
            else
                rhs = scale_operand(rhs);
        }
        else
        {
            return;
        }

        // m.end() marks an operand that could not be scaled
        if(contains({lhs, rhs}, m.end()))
            return;

        m.replace_instruction(ins, make_op("dot"), lhs, rhs);
    }
};
// Rewrites dot(mul(z, d), c) / dot(c, mul(z, d)), where `d` is a constant
// broadcast/multibroadcast, `c` is constant and `z` is not, so that `d` is
// multiplied into the constant operand `c` and the dot consumes `z` directly.
struct find_dot_mul
{
    auto matcher() const
    {
        auto const_bcast = match::name("broadcast", "multibroadcast")(match::is_constant());
        auto non_const   = match::none_of(match::is_constant());
        auto mul_matcher = match::name("mul")(
            match::used_once(),
            match::either_arg(0, 1)(const_bcast.bind("d"), non_const.bind("z")));
        return match::name("dot")(
            match::either_arg(0, 1)(mul_matcher, match::is_constant().bind("c")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto a_ins = ins->inputs()[0];
        auto b_ins = ins->inputs()[1];
        auto d_ins = r.instructions["d"];
        auto c_ins = r.instructions["c"];
        auto z_ins = r.instructions["z"];

        const auto& d_strides = d_ins->get_shape().strides();

        // Bail out unless at most one stride of the broadcast is non-zero
        const auto nonzero_strides =
            std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; });
        if(nonzero_strides > 1)
            return;

        if(not d_ins->get_shape().scalar())
        {
            const bool along_last        = (d_strides.back() == 1);
            const bool along_second_last = (d_strides[d_strides.size() - 2] == 1);
            if(along_last and not b_ins->can_eval())
                return;
            if(along_second_last and not a_ins->can_eval())
                return;
        }

        // Permutation that swaps the last two axes
        auto c_lens = c_ins->get_shape().lens();
        std::vector<int64_t> permutation(c_lens.size());
        std::iota(permutation.begin(), permutation.end(), 0);
        std::swap(permutation.back(), permutation[permutation.size() - 2]);

        // Re-create the broadcast with the swapped lens of `c`, then transpose
        // it so its result can be multiplied with `c`.
        c_lens                  = reorder_dims(c_lens, permutation);
        auto bcast_attrs        = d_ins->get_operator().to_value();
        bcast_attrs["out_lens"] = c_lens;
        auto db_ins =
            m.insert_instruction(ins, make_op(d_ins->name(), bcast_attrs), d_ins->inputs());
        auto db_transpose_ins =
            m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
        auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);

        // Keep the scaled constant on the side where `c` was
        const bool c_is_rhs = (c_ins == b_ins);
        auto lhs            = c_is_rhs ? z_ins : cd_ins;
        auto rhs            = c_is_rhs ? cd_ins : z_ins;
        m.replace_instruction(ins, make_op("dot"), lhs, rhs);
    }
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
// When a * (x + b) is followed by another add of constant, then the
// additional add can be const folded. Also, better fusions can be applied
// when the add comes after.
struct
find_mul_add
{
auto
matcher
()
const
...
...
@@ -356,30 +486,118 @@ struct find_inner_broadcast
{
// Matches a pointwise instruction whose inputs are all broadcasts.
auto matcher() const
{
    auto all_inputs_broadcast = match::all_of[match::inputs()](match::broadcast());
    return pointwise(all_inputs_broadcast);
}
// Builds a predicate that accepts an instruction only when its shape is not
// scalar and its name equals `name`.
static auto non_scalar_op(const std::string& name)
{
    return [=](instruction_ref ins) {
        const bool is_scalar = ins->get_shape().scalar();
        return not is_scalar and ins->name() == name;
    };
}
// Moves the broadcast from the inputs of a pointwise instruction to its
// output: the pointwise op is applied to the un-broadcast inputs (through
// insert_common_op) and the result is broadcast once.
void apply(module& m, const match::matcher_result& r) const
{
    auto ins        = r.result;
    auto broadcasts = ins->inputs();
    if(broadcasts.empty())
        return;

    // Both a non-scalar "broadcast" and a non-scalar "multibroadcast" are present
    bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
                            any_of(broadcasts, non_scalar_op("multibroadcast"));

    // True for a non-scalar "broadcast" whose input has too few unit-length
    // dimensions, i.e. the input is not a single dimension.
    auto not_single_dim = [&](instruction_ref i) {
        if(i->get_shape().scalar())
            return false;
        if(i->name() == "multibroadcast")
            return false;
        auto input          = i->inputs().at(0);
        const auto& lens    = input->get_shape().lens();
        const auto unit_cnt = std::count_if(
            lens.begin(), lens.end(), [&](std::size_t d) { return d == 1; });
        return unit_cnt < (lens.size() - 1);
    };
    // If the broadcast is not a single dimension, then do not perform inner_broadcast
    if(mixed_broadcasts and any_of(broadcasts, not_single_dim))
        return;

    // Collect the un-broadcast inputs; with mixed broadcasts, non-scalar
    // inputs of rank > 1 are squeezed first.
    std::vector<instruction_ref> inputs;
    for(instruction_ref i : broadcasts)
    {
        auto input = i->inputs().front();
        if(mixed_broadcasts and not i->get_shape().scalar() and i->get_shape().lens().size() > 1)
            inputs.push_back(m.insert_instruction(i, make_op("squeeze"), input));
        else
            inputs.push_back(input);
    }

    // Order: "broadcast" first, then "multibroadcast", then scalars, then
    // anything else; the operator of the first one is reused for the output.
    auto priority = [](instruction_ref i) {
        if(i->get_shape().scalar())
            return 2;
        if(i->name() == "broadcast")
            return 0;
        if(i->name() == "multibroadcast")
            return 1;
        return 3;
    };
    std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, priority));

    auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
    m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
// For a dot whose two inputs are broadcasts of the same kind, drops the
// leading batch axes that are broadcast (stride 0) in both inputs, performs
// the dot on the smaller operands and multibroadcasts the result back to the
// original output lens.
struct find_dot_broadcast
{
    auto matcher() const
    {
        return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto a   = ins->inputs()[0];
        auto b   = ins->inputs()[1];
        // Both inputs must use the same broadcast operator
        if(a->get_operator().name() != b->get_operator().name())
            return;
        // Without batch axes there is nothing to remove
        if(ins->get_shape().lens().size() < 3)
            return;
        auto nbatch_axes      = ins->get_shape().lens().size() - 2;
        const auto& a_strides = a->get_shape().strides();
        const auto& b_strides = b->get_shape().strides();
        // Find leading batch axes that are broadcasted
        auto p =
            std::mismatch(a_strides.begin(),
                          a_strides.begin() + nbatch_axes,
                          b_strides.begin(),
                          b_strides.begin() + nbatch_axes,
                          [](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
        auto naxes = p.first - a_strides.begin();
        assert(naxes >= 0 and static_cast<std::size_t>(naxes) <= nbatch_axes);
        // No leading axis is broadcast in both inputs: rewriting would only
        // recreate the same dot followed by a no-op multibroadcast.
        if(naxes == 0)
            return;

        // Re-creates the broadcast `b_ins` without the first `naxes` axes
        auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
            auto input = b_ins->inputs()[0];
            std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
                                          b_ins->get_shape().lens().end());
            if(b_ins->name() == "multibroadcast")
            {
                return m.insert_instruction(
                    ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
            }
            else if(b_ins->name() == "broadcast")
            {
                auto v = b_ins->get_operator().to_value();
                // the broadcast axis moves left by the number of removed axes
                auto axis = v.at("axis").to<std::size_t>() - naxes;
                return m.insert_instruction(
                    ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
            }
            assert(false);
            return m.end();
        };
        auto a1        = insert_broadcast(a);
        auto b1        = insert_broadcast(b);
        auto dot       = m.insert_instruction(ins, make_op("dot"), a1, b1);
        auto broadcast = m.insert_instruction(
            ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
        m.replace_instruction(ins, broadcast);
    }
};
...
...
@@ -388,7 +606,8 @@ struct find_concat_op
auto
matcher
()
const
{
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
)),
match
::
used_once
()));
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)),
match
::
used_once
()));
}
template
<
class
Iterator
>
...
...
@@ -407,7 +626,8 @@ struct find_concat_op
static
bool
is_valid_op
(
const
operation
&
op
)
{
return
op
.
name
()
==
"broadcast"
or
op
.
attributes
().
contains
(
"pointwise"
);
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
())
or
op
.
attributes
().
contains
(
"pointwise"
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -435,6 +655,16 @@ struct find_concat_op
op
=
b
;
iaxis
=
0
;
}
else
if
(
op
.
name
()
==
"multibroadcast"
)
{
shape
bshape
=
(
*
start
)
->
get_shape
();
auto
input
=
(
*
start
)
->
inputs
()[
0
];
if
(
iaxis
>=
bshape
.
strides
().
size
()
or
bshape
.
strides
()[
iaxis
]
==
0
)
return
{
start
,
last
};
op
.
from_value
({{
"out_lens"
,
get_output_lens
(
start
,
last
,
iaxis
)}});
auto
delta
=
bshape
.
lens
().
size
()
-
input
->
get_shape
().
lens
().
size
();
iaxis
-=
delta
;
}
std
::
vector
<
instruction_ref
>
concats
;
for
(
std
::
size_t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
...
...
@@ -1009,7 +1239,7 @@ struct find_neg_unit_ops
auto
ins
=
r
.
result
;
auto
c_in
=
r
.
instructions
[
"x"
];
auto
neg
=
m
.
add
_instruction
(
make_op
(
"neg"
),
c_in
);
auto
neg
=
m
.
insert
_instruction
(
ins
,
make_op
(
"neg"
),
c_in
);
m
.
replace_instruction
(
ins
,
neg
);
}
};
...
...
@@ -1255,12 +1485,15 @@ void simplify_algebra::apply(module& m) const
{
match
::
find_matches
(
m
,
find_inner_broadcast
{},
find_dot_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_convs
{},
find_conv_dot_horiz_fusion
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_dot
{},
find_dot_mul
{},
find_mul_add
{},
find_unit_ops
{},
find_neg_unit_ops
{},
...
...
src/simplify_reshapes.cpp
View file @
baac1dab
...
...
@@ -762,7 +762,7 @@ struct find_transpose_slice
return
;
// Compute axis before transpose to use for unsqueeze
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
()
;
auto
preaxis
=
perm
[
axis
]
;
// Make unsqueeze
std
::
vector
<
int64_t
>
steps
(
sdistance
.
size
());
std
::
transform
(
...
...
src/split_single_dyn_dim.cpp
0 → 100644
View file @
baac1dab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// Result record describing one non-fixed dynamic dimension of a module parameter.
struct dynamic_dimensions_check
{
    /// Name of the parameter whose shape is dynamic
    std::string dyn_param_str;
    /// Index of the non-fixed dimension within that parameter's dynamic dimensions
    std::size_t dyn_index;
    /// Lower bound copied from the dimension's `min` field
    std::size_t min_dim;
    /// Upper bound copied from the dimension's `max` field
    std::size_t max_dim;
};
/**
 * Search the parameter shapes for a lone non-fixed dynamic dimension.
 *
 * A populated dynamic_dimensions_check is produced only when exactly one
 * parameter reports a dynamic shape and exactly one of that shape's
 * dynamic_dimensions is not fixed. Any other configuration gives std::nullopt.
 */
optional<dynamic_dimensions_check>
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
    const auto shape_is_dynamic = [](const auto& entry) { return entry.second.dynamic(); };
    const auto dim_can_vary     = [](const auto& dim) { return not dim.is_fixed(); };

    const auto params_end = param_shapes.end();
    const auto dyn_param  = std::find_if(param_shapes.begin(), params_end, shape_is_dynamic);
    // Bail out when no parameter is dynamic, or when another dynamic one follows the first
    if(dyn_param == params_end or
       std::find_if(std::next(dyn_param), params_end, shape_is_dynamic) != params_end)
        return std::nullopt;

    const auto& dims    = dyn_param->second.dyn_dims();
    const auto dims_end = dims.end();
    const auto vary_dim = std::find_if(dims.begin(), dims_end, dim_can_vary);
    // Same rule one level down: exactly one dimension may be non-fixed
    if(vary_dim == dims_end or
       std::find_if(std::next(vary_dim), dims_end, dim_can_vary) != dims_end)
        return std::nullopt;

    const auto position = static_cast<std::size_t>(std::distance(dims.begin(), vary_dim));
    return dynamic_dimensions_check{dyn_param->first, position, vary_dim->min, vary_dim->max};
}
namespace
{
struct find_static_2in_broadcasts
{
    // Rewrites a two-input broadcast/multibroadcast whose inputs both have static
    // shapes into the single-input form of the same operator. Some compiler passes
    // (ex. simplify_algebra) only support the 1 input versions of the broadcasting
    // operators.
    auto matcher() const
    {
        return match::broadcast(match::nargs(2),
                                match::arg(0)(match::static_shape()),
                                match::arg(1)(match::static_shape()));
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto bcast_ins   = mr.result;
        auto single_op   = bcast_ins->get_operator();
        auto target_lens = bcast_ins->get_shape().lens();

        if(single_op.name() != "broadcast")
        {
            // Non-"broadcast" match: also reset the dynamic output dimensions
            single_op.from_value({{"out_lens", target_lens}, {"out_dyn_dims", {}}});
        }
        else
        {
            single_op.from_value({{"out_lens", target_lens}});
        }

        // Only the first input is kept; the output lens now live in the operator
        m.replace_instruction(bcast_ins, single_op, bcast_ins->inputs().at(0));
    }
};
}
// namespace
/**
 * When the module has exactly one non-fixed dynamic dimension, build one
 * static-shaped submodule per value in that dimension's [min, max] range and
 * route the module through a select_module instruction that receives every
 * parameter plus those submodules. The module's return is replaced by
 * get_tuple_elem instructions reading the select_module result, bypassing the
 * other instructions. Nothing is done if the dynamic parameter already outputs
 * to a select_module operator. Probably won't work for `if` and `loop`
 * instructions, depending on how the submodules for those work.
 */
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
    module_ref mm     = &mpm.get_module();
    auto param_names  = mm->get_parameter_names();
    auto param_shapes = mm->get_parameter_shapes();

    optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
    if(not dd_check.has_value())
        return;

    // Skip when a select_module already consumes the dynamic parameter
    auto consumers = mm->get_parameter(dd_check->dyn_param_str)->outputs();
    for(auto consumer : consumers)
    {
        if(consumer->name() == "select_module")
            return;
    }

    const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
    auto dyn_param_shape  = mm->get_parameter_shape(dd_check->dyn_param_str);

    // One submodule per possible size of the dynamic dimension
    std::vector<module_ref> submodules;
    for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1))
    {
        auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));

        // Static lens: start from the max lens and pin the dynamic index to dim_size
        auto static_lens                    = dyn_param_shape.max_lens();
        static_lens.at(dd_check->dyn_index) = dim_size;
        auto static_param                   = submod->add_parameter(
            dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});

        // Maps the dynamic parameter to its static replacement in the submodule
        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        map_ins[dyn_param] = static_param;

        auto sub_outputs = submod->add_instructions(mm, map_ins);
        submod->add_return({sub_outputs});
        match::find_matches(*submod, find_static_2in_broadcasts{});
        submodules.push_back(submod);
    }

    // Every module parameter becomes an input of select_module
    std::vector<instruction_ref> sm_inputs;
    for(const auto& pn : param_names)
        sm_inputs.push_back(mm->get_parameter(pn));

    auto output_shapes       = mm->get_output_shapes();
    migraphx::shape out_attr = migraphx::shape{output_shapes};
    auto sm_ins              = mm->add_instruction(
        migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
        sm_inputs,
        submodules);

    // Unpack the select_module result, one get_tuple_elem per output shape
    std::vector<instruction_ref> outputs;
    for(size_t i = 0; i < output_shapes.size(); ++i)
    {
        outputs.push_back(
            mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", i}}), sm_ins));
    }
    mm->replace_return(outputs);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/cpu/include/migraphx/cpu/parallel.hpp
View file @
baac1dab
...
...
@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP
// #define MIGRAPHX_DISABLE_OMP
#include <cmath>
#include <migraphx/config.hpp>
#ifdef MIGRAPHX_DISABLE_OMP
#include <migraphx/par_for.hpp>
...
...
src/targets/cpu/include/migraphx/cpu/target.hpp
View file @
baac1dab
...
...
@@ -40,14 +40,11 @@ struct target
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
gctx
,
const
compile_options
&
)
const
;
migraphx
::
context
get_context
()
const
{
return
context
{};
}
argument
copy_to
(
const
argument
&
arg
)
const
{
return
arg
;
}
argument
copy_from
(
const
argument
&
arg
)
const
{
return
arg
;
}
argument
allocate
(
const
shape
&
s
)
const
;
};
MIGRAPHX_REGISTER_TARGET
(
target
);
}
// namespace cpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment