Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5ec8f913
Commit
5ec8f913
authored
Sep 13, 2022
by
Ted Themistokleous
Committed by
Ted Themistokleous
Sep 13, 2022
Browse files
Merge branch 'develop' into simplify_1_mul_div_ops
parents
32d69e8e
d78bcdfb
Changes
183
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
410 additions
and
103 deletions
+410
-103
src/program.cpp
src/program.cpp
+33
-8
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+30
-11
src/quantization.cpp
src/quantization.cpp
+1
-1
src/rewrite_gelu.cpp
src/rewrite_gelu.cpp
+59
-0
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+3
-3
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+7
-7
src/shape.cpp
src/shape.cpp
+2
-2
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+76
-28
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+94
-5
src/targets/cpu/binary.cpp
src/targets/cpu/binary.cpp
+1
-1
src/targets/fpga/include/migraphx/fpga/target.hpp
src/targets/fpga/include/migraphx/fpga/target.hpp
+2
-1
src/targets/fpga/subgraph.cpp
src/targets/fpga/subgraph.cpp
+1
-1
src/targets/fpga/target.cpp
src/targets/fpga/target.cpp
+10
-4
src/targets/gpu/code_object_op.cpp
src/targets/gpu/code_object_op.cpp
+2
-1
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+75
-22
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+5
-3
src/targets/gpu/device/multinomial.cpp
src/targets/gpu/device/multinomial.cpp
+1
-1
src/targets/gpu/driver/compile_op.cpp
src/targets/gpu/driver/compile_op.cpp
+5
-2
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp
+2
-1
No files found.
src/program.cpp
View file @
5ec8f913
...
@@ -37,6 +37,7 @@
...
@@ -37,6 +37,7 @@
#include <migraphx/output_iterator.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
#include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp>
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include <algorithm>
#include <algorithm>
...
@@ -77,11 +78,11 @@ program& program::operator=(program p)
...
@@ -77,11 +78,11 @@ program& program::operator=(program p)
void
program
::
assign
(
const
program
&
p
)
void
program
::
assign
(
const
program
&
p
)
{
{
if
(
!
impl
)
if
(
not
impl
)
{
{
impl
=
std
::
make_unique
<
program_impl
>
();
impl
=
std
::
make_unique
<
program_impl
>
();
}
}
else
if
(
!
impl
->
modules
.
empty
())
else
if
(
not
impl
->
modules
.
empty
())
{
{
impl
->
modules
.
clear
();
impl
->
modules
.
clear
();
}
}
...
@@ -167,13 +168,37 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
...
@@ -167,13 +168,37 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
target_assignments
p
;
target_assignments
p
;
const
auto
*
mod
=
get_main_module
();
const
auto
*
mod
=
get_main_module
();
for
(
auto
it
:
iterator_for
(
*
mod
))
std
::
vector
<
std
::
pair
<
target
,
supported_segments
>>
target_subgraphs
;
target_subgraphs
.
reserve
(
targets
.
size
());
std
::
transform
(
targets
.
begin
(),
targets
.
end
(),
std
::
back_inserter
(
target_subgraphs
),
[
&
](
const
auto
&
t
)
{
return
std
::
make_pair
(
t
,
t
.
find_supported
(
mod
,
m
));
});
for
(
const
auto
ins
:
iterator_for
(
*
mod
))
{
{
auto
t
=
std
::
max_element
(
if
(
contains
(
p
,
ins
))
targets
.
begin
(),
targets
.
end
(),
[
it
,
m
](
const
target
&
lhs
,
const
target
&
rhs
)
{
{
return
lhs
.
is_supported
(
it
,
m
)
<
rhs
.
is_supported
(
it
,
m
);
continue
;
});
}
p
.
add_assignment
(
it
,
t
->
name
());
for
(
const
auto
&
[
target
,
subgraph
]
:
target_subgraphs
)
{
// can't pass a structured binding into lambda in C++17 so create a variable for it
const
auto
&
t
=
target
;
for
(
const
auto
&
segment
:
subgraph
)
{
const
auto
&
instructions
=
segment
.
instructions
;
if
(
not
contains
(
instructions
,
ins
))
{
continue
;
}
std
::
transform
(
instructions
.
begin
(),
instructions
.
end
(),
std
::
inserter
(
p
,
p
.
end
()),
[
&
](
auto
instr
)
{
return
std
::
make_pair
(
instr
,
t
.
name
());
});
}
}
}
}
return
p
;
return
p
;
}
}
...
...
src/py/migraphx_py.cpp
View file @
5ec8f913
...
@@ -40,6 +40,7 @@
...
@@ -40,6 +40,7 @@
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#ifdef HAVE_GPU
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/hip.hpp>
...
@@ -82,7 +83,7 @@ void visit_py(T x, F f)
...
@@ -82,7 +83,7 @@ void visit_py(T x, F f)
{
{
f
(
x
.
template
cast
<
bool
>());
f
(
x
.
template
cast
<
bool
>());
}
}
else
if
(
py
::
isinstance
<
py
::
int_
>
(
x
))
else
if
(
py
::
isinstance
<
py
::
int_
>
(
x
)
or
py
::
hasattr
(
x
,
"__index__"
)
)
{
{
f
(
x
.
template
cast
<
int
>());
f
(
x
.
template
cast
<
int
>());
}
}
...
@@ -324,6 +325,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -324,6 +325,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"get_parameter_names"
,
&
migraphx
::
program
::
get_parameter_names
)
.
def
(
"get_parameter_names"
,
&
migraphx
::
program
::
get_parameter_names
)
.
def
(
"get_parameter_shapes"
,
&
migraphx
::
program
::
get_parameter_shapes
)
.
def
(
"get_parameter_shapes"
,
&
migraphx
::
program
::
get_parameter_shapes
)
.
def
(
"get_output_shapes"
,
&
migraphx
::
program
::
get_output_shapes
)
.
def
(
"get_output_shapes"
,
&
migraphx
::
program
::
get_output_shapes
)
.
def
(
"is_compiled"
,
&
migraphx
::
program
::
is_compiled
)
.
def
(
.
def
(
"compile"
,
"compile"
,
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
)
{
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
offload_copy
,
bool
fast_math
)
{
...
@@ -358,18 +360,35 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -358,18 +360,35 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
program
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
program
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
program
&
p
)
{
return
migraphx
::
to_string
(
p
);
});
.
def
(
"__repr__"
,
[](
const
migraphx
::
program
&
p
)
{
return
migraphx
::
to_string
(
p
);
});
py
::
class_
<
migraphx
::
operation
>
(
m
,
"op"
)
py
::
class_
<
migraphx
::
operation
>
op
(
m
,
"op"
);
.
def
(
py
::
init
([](
const
std
::
string
&
name
,
py
::
kwargs
kwargs
)
{
op
.
def
(
py
::
init
([](
const
std
::
string
&
name
,
py
::
kwargs
kwargs
)
{
migraphx
::
value
v
=
migraphx
::
value
::
object
{};
migraphx
::
value
v
=
migraphx
::
value
::
object
{};
if
(
kwargs
)
if
(
kwargs
)
{
{
v
=
migraphx
::
to_value
(
kwargs
);
v
=
migraphx
::
to_value
(
kwargs
);
}
}
return
migraphx
::
make_op
(
name
,
v
);
return
migraphx
::
make_op
(
name
,
v
);
}))
}))
.
def
(
"name"
,
&
migraphx
::
operation
::
name
);
.
def
(
"name"
,
&
migraphx
::
operation
::
name
);
py
::
enum_
<
migraphx
::
op
::
pooling_mode
>
(
op
,
"pooling_mode"
)
.
value
(
"average"
,
migraphx
::
op
::
pooling_mode
::
average
)
.
value
(
"max"
,
migraphx
::
op
::
pooling_mode
::
max
)
.
value
(
"lpnorm"
,
migraphx
::
op
::
pooling_mode
::
lpnorm
);
py
::
enum_
<
migraphx
::
op
::
rnn_direction
>
(
op
,
"rnn_direction"
)
.
value
(
"forward"
,
migraphx
::
op
::
rnn_direction
::
forward
)
.
value
(
"reverse"
,
migraphx
::
op
::
rnn_direction
::
reverse
)
.
value
(
"bidirectional"
,
migraphx
::
op
::
rnn_direction
::
bidirectional
);
m
.
def
(
"argument_from_pointer"
,
[](
const
migraphx
::
shape
shape
,
const
int64_t
address
)
{
return
migraphx
::
argument
(
shape
,
reinterpret_cast
<
void
*>
(
address
));
},
py
::
arg
(
"shape"
),
py
::
arg
(
"address"
));
m
.
def
(
m
.
def
(
"parse_tf"
,
"parse_tf"
,
[](
const
std
::
string
&
filename
,
[](
const
std
::
string
&
filename
,
...
...
src/quantization.cpp
View file @
5ec8f913
...
@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
...
@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
{
{
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
if
(
!
std
::
includes
(
if
(
not
std
::
includes
(
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
{
{
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
...
...
src/rewrite_gelu.cpp
0 → 100644
View file @
5ec8f913
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
find_gelu_erf
{
auto
matcher
()
const
{
return
match
::
gelu_erf
();
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x
=
r
.
instructions
[
"x"
];
if
(
x
->
get_shape
().
type
()
!=
migraphx
::
shape
::
half_type
)
return
;
auto
lit
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
1.702
f
}});
auto
mul
=
insert_common_op
(
m
,
ins
,
make_op
(
"mul"
),
{
x
,
lit
});
auto
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"neg"
),
mul
);
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"exp"
),
sig
);
auto
one
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
1.0
f
}});
sig
=
insert_common_op
(
m
,
ins
,
make_op
(
"add"
),
{
sig
,
one
});
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
sig
);
m
.
replace_instruction
(
ins
,
sig
);
}
};
void
rewrite_gelu
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_gelu_erf
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/rewrite_pooling.cpp
View file @
5ec8f913
...
@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
...
@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
if
(
not
s
.
standard
())
if
(
not
s
.
standard
())
continue
;
continue
;
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
!
std
::
all_of
(
op
.
padding
.
begin
(),
op
.
padding
.
end
(),
[](
auto
i
)
{
return
i
==
0
;
}))
if
(
not
std
::
all_of
(
op
.
padding
.
begin
(),
op
.
padding
.
end
(),
[](
auto
i
)
{
return
i
==
0
;
}))
continue
;
continue
;
if
(
!
std
::
all_of
(
op
.
stride
.
begin
(),
op
.
stride
.
end
(),
[](
auto
i
)
{
return
i
==
1
;
}))
if
(
not
std
::
all_of
(
op
.
stride
.
begin
(),
op
.
stride
.
end
(),
[](
auto
i
)
{
return
i
==
1
;
}))
continue
;
continue
;
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
if
(
!
std
::
equal
(
lens
.
begin
()
+
2
,
lens
.
end
(),
op
.
lengths
.
begin
(),
op
.
lengths
.
end
()))
if
(
not
std
::
equal
(
lens
.
begin
()
+
2
,
lens
.
end
(),
op
.
lengths
.
begin
(),
op
.
lengths
.
end
()))
continue
;
continue
;
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
...
...
src/rewrite_rnn.cpp
View file @
5ec8f913
...
@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
}
if
(
!
is_forward
and
variable_seq_len
)
if
(
not
is_forward
and
variable_seq_len
)
{
{
args
[
0
]
=
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
...
@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
}
if
(
!
is_forward
and
variable_seq_len
)
if
(
not
is_forward
and
variable_seq_len
)
{
{
args
[
0
]
=
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
...
@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
pph
=
args
[
7
];
pph
=
args
[
7
];
}
}
if
(
!
is_forward
and
variable_seq_len
)
if
(
not
is_forward
and
variable_seq_len
)
{
{
args
[
0
]
=
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
...
@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
...
@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
std
::
vector
<
int64_t
>
vec_lens
;
std
::
vector
<
int64_t
>
vec_lens
;
arg_lens
.
visit
([
&
](
auto
l
)
{
vec_lens
.
assign
(
l
.
begin
(),
l
.
end
());
});
arg_lens
.
visit
([
&
](
auto
l
)
{
vec_lens
.
assign
(
l
.
begin
(),
l
.
end
());
});
int64_t
l
=
0
;
int64_t
l
=
0
;
if
(
!
vec_lens
.
empty
())
if
(
not
vec_lens
.
empty
())
{
{
l
=
vec_lens
[
0
];
l
=
vec_lens
[
0
];
}
}
if
(
!
std
::
all_of
(
vec_lens
.
begin
(),
vec_lens
.
end
(),
[
&
](
auto
v
)
{
return
v
==
l
;
}))
if
(
not
std
::
all_of
(
vec_lens
.
begin
(),
vec_lens
.
end
(),
[
&
](
auto
v
)
{
return
v
==
l
;
}))
{
{
is_var_lens
=
true
;
is_var_lens
=
true
;
}
}
...
@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
...
@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
bool
is_var_lens
=
is_variable_seq_lens
(
m
,
seq_lens
);
bool
is_var_lens
=
is_variable_seq_lens
(
m
,
seq_lens
);
auto
input_shape
=
input
->
get_shape
();
auto
input_shape
=
input
->
get_shape
();
auto
length
=
input_shape
.
lens
()[
0
];
auto
length
=
input_shape
.
lens
()[
0
];
if
(
!
is_var_lens
and
seq_lens
!=
m
.
end
())
if
(
not
is_var_lens
and
seq_lens
!=
m
.
end
())
{
{
auto
arg_len
=
seq_lens
->
eval
();
auto
arg_len
=
seq_lens
->
eval
();
std
::
vector
<
std
::
size_t
>
vec_lens
;
std
::
vector
<
std
::
size_t
>
vec_lens
;
...
@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
...
@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
if
(
variable_seq_len
)
if
(
variable_seq_len
)
{
{
if
(
!
ins_outputs
.
empty
())
if
(
not
ins_outputs
.
empty
())
{
{
cell_outputs
=
m
.
insert_instruction
(
cell_outputs
=
m
.
insert_instruction
(
std
::
next
(
ins
),
std
::
next
(
ins
),
...
...
src/shape.cpp
View file @
5ec8f913
...
@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
...
@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
{
{
...
@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
...
@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
}
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
not
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
{
{
...
...
src/simplify_algebra.cpp
View file @
5ec8f913
...
@@ -208,6 +208,42 @@ struct find_mul_add
...
@@ -208,6 +208,42 @@ struct find_mul_add
}
}
};
};
struct
find_dot_add
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
match
::
is_constant
()).
bind
(
"b"
)),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
used_once
()),
match
::
is_constant
().
bind
(
"a"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
assert
(
x_ins
!=
b_ins
);
const
bool
flipped
=
a_ins
==
ins
->
inputs
().
back
();
auto
insert_dot
=
[
&
](
auto
x
,
auto
y
)
{
if
(
flipped
)
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
y
,
x
);
else
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
x
,
y
);
};
auto
ax_ins
=
insert_dot
(
a_ins
,
x_ins
);
auto
ab_ins
=
insert_dot
(
a_ins
,
b_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
ax_ins
,
ab_ins
);
}
};
struct
find_add_lit_broadcast
struct
find_add_lit_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
...
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
struct
find_inner_broadcast
struct
find_inner_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
{
return
pointwise
(
match
::
nargs
(
2
),
match
::
args
(
match
::
name
(
"broadcast"
).
bind
(
"x"
),
match
::
name
(
"broadcast"
).
bind
(
"y"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
broadcasts
=
ins
->
inputs
();
auto
y_ins
=
r
.
instructions
[
"y"
];
if
(
broadcasts
.
empty
())
return
;
auto
xbroadcast
=
any_cast
<
op
::
broadcast
>
(
x_ins
->
get_operator
());
std
::
vector
<
instruction_ref
>
inputs
;
auto
ybroadcast
=
any_cast
<
op
::
broadcast
>
(
y_ins
->
get_operator
());
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
if
(
xbroadcast
.
axis
!=
ybroadcast
.
axis
)
std
::
back_inserter
(
inputs
),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
();
}))
return
;
return
;
auto
op
=
m
.
insert_instruction
(
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
ins
,
ins
->
get_operator
(),
x_ins
->
inputs
().
front
(),
y_ins
->
inputs
().
front
());
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
xbroadcast
,
op
);
}
}
};
};
...
@@ -416,8 +450,9 @@ struct find_splits
...
@@ -416,8 +450,9 @@ struct find_splits
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
pointwise
(),
reduction
()))));
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
match
::
any_of
[
match
::
outputs
()](
match
::
pointwise
(
match
::
any_of
(
match
::
nargs
(
1
),
match
::
nargs
(
2
))),
reduction
()))));
}
}
static
bool
is_dependent
(
const
module
&
m
,
instruction_ref
ins1
,
instruction_ref
ins2
)
static
bool
is_dependent
(
const
module
&
m
,
instruction_ref
ins1
,
instruction_ref
ins2
)
...
@@ -580,10 +615,9 @@ struct find_splits
...
@@ -580,10 +615,9 @@ struct find_splits
auto
outputs
=
i
->
outputs
();
auto
outputs
=
i
->
outputs
();
for
(
auto
output
:
outputs
)
for
(
auto
output
:
outputs
)
{
{
if
(
not
contains
({
"reshape"
,
"squeeze"
,
"unsqueeze"
},
output
->
name
()
)
)
if
(
output
->
name
()
!=
"reshape"
)
continue
;
continue
;
auto
x
=
auto
x
=
m
.
insert_instruction
(
output
,
make_op
(
"contiguous"
),
i
);
m
.
insert_instruction
(
output
,
make_op
(
"contiguous"
),
output
->
inputs
());
m
.
replace_instruction
(
output
,
output
->
get_operator
(),
x
);
m
.
replace_instruction
(
output
,
output
->
get_operator
(),
x
);
}
}
...
@@ -753,7 +787,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
...
@@ -753,7 +787,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
!
(
dots
<
2
and
convs
<
2
);
return
not
(
dots
<
2
and
convs
<
2
);
}
}
struct
find_conv_dot_horiz_fusion
struct
find_conv_dot_horiz_fusion
...
@@ -773,7 +807,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -773,7 +807,7 @@ struct find_conv_dot_horiz_fusion
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
if
(
x
.
size
()
!=
y
.
size
())
if
(
x
.
size
()
!=
y
.
size
())
return
false
;
return
false
;
// Check that non-ax
is
es match
// Check that non-axes match
int
axis
=
1
;
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
)
{
{
...
@@ -809,13 +843,22 @@ struct find_conv_dot_horiz_fusion
...
@@ -809,13 +843,22 @@ struct find_conv_dot_horiz_fusion
for
(
auto
arg
:
args
)
for
(
auto
arg
:
args
)
m
.
move_instructions
(
arg
,
input
);
m
.
move_instructions
(
arg
,
input
);
// TODO: Check if ax
is
es match
// TODO: Check if axes match
auto
concat
=
auto
concat
=
m
.
insert_instruction
(
input
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
args
);
m
.
insert_instruction
(
input
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
args
);
auto
fused
=
m
.
insert_instruction
(
std
::
next
(
input
),
op
,
input
,
concat
);
auto
fused
=
m
.
insert_instruction
(
std
::
next
(
input
),
op
,
input
,
concat
);
int64_t
offset
=
0
;
int64_t
offset
=
0
;
for
(
auto
arg
:
range
(
start
,
last
))
for
(
auto
arg
:
range
(
start
,
last
))
{
{
auto
outputs
=
arg
->
outputs
();
for
(
auto
output
:
outputs
)
{
if
(
output
->
name
()
!=
"reshape"
)
continue
;
auto
x
=
m
.
insert_instruction
(
output
,
make_op
(
"contiguous"
),
arg
);
m
.
replace_instruction
(
output
,
output
->
get_operator
(),
x
);
}
int64_t
len
=
arg
->
get_shape
().
lens
()[
axis
];
int64_t
len
=
arg
->
get_shape
().
lens
()[
axis
];
m
.
replace_instruction
(
m
.
replace_instruction
(
arg
,
arg
,
...
@@ -993,7 +1036,7 @@ struct find_split_reshape
...
@@ -993,7 +1036,7 @@ struct find_split_reshape
// all outputs are reshape and of the same shape
// all outputs are reshape and of the same shape
auto
dims
=
any_cast
<
op
::
reshape
>
(
rsp
->
get_operator
()).
dims
;
auto
dims
=
any_cast
<
op
::
reshape
>
(
rsp
->
get_operator
()).
dims
;
if
(
!
same_ops
(
vec_rsp
))
if
(
not
same_ops
(
vec_rsp
))
{
{
return
;
return
;
}
}
...
@@ -1025,7 +1068,11 @@ struct find_split_reshape
...
@@ -1025,7 +1068,11 @@ struct find_split_reshape
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
// insert the reshape instruction
// insert the reshape instruction and add contiguous if needed
if
(
not
input
->
get_shape
().
standard
())
{
input
=
m
.
insert_instruction
(
std
::
next
(
input
),
make_op
(
"contiguous"
),
input
);
}
auto
rsp_ins
=
m
.
insert_instruction
(
auto
rsp_ins
=
m
.
insert_instruction
(
std
::
next
(
input
),
make_op
(
"reshape"
,
{{
"dims"
,
rsp_out_lens
}}),
input
);
std
::
next
(
input
),
make_op
(
"reshape"
,
{{
"dims"
,
rsp_out_lens
}}),
input
);
...
@@ -1072,7 +1119,7 @@ struct find_split_transpose
...
@@ -1072,7 +1119,7 @@ struct find_split_transpose
// all transpose are the same
// all transpose are the same
auto
perm
=
any_cast
<
op
::
transpose
>
(
trans
->
get_operator
()).
dims
;
auto
perm
=
any_cast
<
op
::
transpose
>
(
trans
->
get_operator
()).
dims
;
if
(
!
same_ops
(
vec_trans
))
if
(
not
same_ops
(
vec_trans
))
{
{
return
;
return
;
}
}
...
@@ -1118,6 +1165,7 @@ void simplify_algebra::apply(module& m) const
...
@@ -1118,6 +1165,7 @@ void simplify_algebra::apply(module& m) const
find_unit_ops
{},
find_unit_ops
{},
find_neg_unit_ops
{},
find_neg_unit_ops
{},
find_zero_ops
{},
find_zero_ops
{},
find_dot_add
{},
find_div_const
{},
find_div_const
{},
find_sub_const
{},
find_sub_const
{},
find_rsqrt
{},
find_rsqrt
{},
...
...
src/simplify_reshapes.cpp
View file @
5ec8f913
...
@@ -99,7 +99,7 @@ struct find_reshaper
...
@@ -99,7 +99,7 @@ struct find_reshaper
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
while
(
is_reshaper
(
reshapes
.
back
()))
{
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
not
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
reshapes
.
push_back
(
input
);
...
@@ -151,8 +151,11 @@ struct find_transpose
...
@@ -151,8 +151,11 @@ struct find_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
auto
output_not_transpose
=
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
)));
auto
input_has_transpose
=
match
::
args
(
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
)));
return
match
::
name
(
"transpose"
)(
output_not_transpose
,
input_has_transpose
);
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
...
@@ -285,7 +288,7 @@ struct find_concat_transpose
...
@@ -285,7 +288,7 @@ struct find_concat_transpose
auto
permutation
=
find_permutation
(
s
);
auto
permutation
=
find_permutation
(
s
);
// permutation should be the same for all inputs
// permutation should be the same for all inputs
if
(
!
std
::
all_of
(
trans_inputs
.
begin
(),
trans_inputs
.
end
(),
[
&
](
auto
in
)
{
if
(
not
std
::
all_of
(
trans_inputs
.
begin
(),
trans_inputs
.
end
(),
[
&
](
auto
in
)
{
return
(
find_permutation
(
in
->
get_shape
())
==
permutation
);
return
(
find_permutation
(
in
->
get_shape
())
==
permutation
);
}))
}))
{
{
...
@@ -664,9 +667,94 @@ struct find_slice_transpose
...
@@ -664,9 +667,94 @@ struct find_slice_transpose
}
}
};
};
struct
find_transpose_slice
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
all_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)));
}
static
std
::
vector
<
int64_t
>
slice_distance
(
const
op
::
slice
&
op
)
{
assert
(
op
.
starts
.
size
()
==
op
.
ends
.
size
());
std
::
vector
<
int64_t
>
result
(
op
.
starts
.
size
());
std
::
transform
(
op
.
ends
.
begin
(),
op
.
ends
.
end
(),
op
.
starts
.
begin
(),
result
.
begin
(),
std
::
minus
<>
{});
return
result
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
slices
=
ins
->
outputs
();
if
(
slices
.
empty
())
return
;
auto
slice
=
any_cast
<
op
::
slice
>
(
slices
.
front
()
->
get_operator
());
auto
sdistance
=
slice_distance
(
slice
);
// Check all distances and axes are the same
if
(
std
::
any_of
(
slices
.
begin
(),
slices
.
end
(),
[
&
](
auto
sins
)
{
auto
s
=
any_cast
<
op
::
slice
>
(
sins
->
get_operator
());
return
s
.
axes
!=
slice
.
axes
or
slice_distance
(
s
)
!=
sdistance
;
}))
return
;
// Check distances are divisible by lens of corresponding axes
auto
mod_by_distance
=
[
&
](
const
auto
&
v
,
auto
f
)
{
return
std
::
inner_product
(
v
.
begin
(),
v
.
end
(),
sdistance
.
begin
(),
0
,
std
::
plus
<>
{},
[
&
](
auto
x
,
auto
d
)
->
uint64_t
{
if
(
d
==
0
)
return
1
;
return
f
(
x
)
%
d
;
});
};
if
(
mod_by_distance
(
slice
.
axes
,
[
&
](
auto
x
)
{
return
ins
->
get_shape
().
lens
()[
x
];
})
!=
0
or
mod_by_distance
(
slice
.
starts
,
id
{})
!=
0
or
mod_by_distance
(
slice
.
ends
,
id
{})
!=
0
)
return
;
// TODO: Handle multiple axes
if
(
sdistance
.
size
()
!=
1
)
return
;
auto
axis
=
slice
.
axes
.
front
();
// Skip if axis would be packed
if
(
std
::
all_of
(
ins
->
get_shape
().
lens
().
begin
(),
ins
->
get_shape
().
lens
().
begin
()
+
axis
,
[](
auto
x
)
{
return
x
==
1
;
}))
return
;
// Compute axis before transpose to use for unsqueeze
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
();
// Make unsqeeze
auto
unsqueeze
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
preaxis
}},
{
"steps"
,
sdistance
}}),
ins
->
inputs
());
// Make transpose
std
::
transform
(
perm
.
begin
(),
perm
.
end
(),
perm
.
begin
(),
[
&
](
auto
i
)
{
if
(
i
>
preaxis
)
return
i
+
1
;
return
i
;
});
perm
.
insert
(
perm
.
begin
(),
preaxis
+
1
);
auto
transpose
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
unsqueeze
);
// Slice and squeeze
for
(
auto
s
:
slices
)
{
auto
op
=
any_cast
<
op
::
slice
>
(
s
->
get_operator
());
op
.
axes
=
{
0
};
op
.
starts
=
{
op
.
starts
.
front
()
/
sdistance
.
front
()};
op
.
ends
=
{
op
.
ends
.
front
()
/
sdistance
.
front
()};
auto
slice_ins
=
m
.
insert_instruction
(
ins
,
op
,
transpose
);
auto
squeeze
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
slice_ins
);
m
.
replace_instruction
(
s
,
squeeze
);
}
}
};
void
simplify_reshapes
::
apply
(
module
&
m
)
const
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_where_op
{},
find_where_op
{},
...
@@ -679,6 +767,7 @@ void simplify_reshapes::apply(module& m) const
...
@@ -679,6 +767,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_convert
{},
find_nested_convert
{},
find_nested_slice
{},
find_nested_slice
{},
find_nested_concat
{},
find_nested_concat
{},
find_transpose_slice
{},
find_slice_transpose
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{});
find_transpose_contiguous_reshaper_unary
{});
dead_code_elimination
{}.
apply
(
m
);
dead_code_elimination
{}.
apply
(
m
);
...
...
src/targets/cpu/binary.cpp
View file @
5ec8f913
...
@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
...
@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
auto
r
=
s0
;
auto
r
=
s0
;
if
(
s0
!=
s1
or
!
s0
.
packed
())
if
(
s0
!=
s1
or
not
s0
.
packed
())
{
{
r
=
shape
{
s0
.
type
(),
s0
.
lens
()};
r
=
shape
{
s0
.
type
(),
s0
.
lens
()};
}
}
...
...
src/targets/fpga/include/migraphx/fpga/target.hpp
View file @
5ec8f913
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/fpga/context.hpp>
#include <migraphx/fpga/context.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/supported_segments.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -41,7 +42,7 @@ struct target
...
@@ -41,7 +42,7 @@ struct target
std
::
string
name
()
const
;
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
ctx
,
const
compile_options
&
)
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
ctx
,
const
compile_options
&
)
const
;
migraphx
::
context
get_context
()
const
{
return
context
{};
}
migraphx
::
context
get_context
()
const
{
return
context
{};
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
);
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
;
argument
copy_to
(
const
argument
&
arg
)
const
{
return
arg
;
}
argument
copy_to
(
const
argument
&
arg
)
const
{
return
arg
;
}
argument
copy_from
(
const
argument
&
arg
)
const
{
return
arg
;
}
argument
copy_from
(
const
argument
&
arg
)
const
{
return
arg
;
}
...
...
src/targets/fpga/subgraph.cpp
View file @
5ec8f913
...
@@ -95,7 +95,7 @@ void subgraph::apply(module_pass_manager& mpm) const
...
@@ -95,7 +95,7 @@ void subgraph::apply(module_pass_manager& mpm) const
for
(
auto
it
:
iterator_for
(
mod
))
for
(
auto
it
:
iterator_for
(
mod
))
{
{
// assuming we want all the params/literals as inputs to the FPGA submodule
// assuming we want all the params/literals as inputs to the FPGA submodule
if
(
migraphx
::
starts_with
(
it
->
name
(),
"@param"
)
||
if
(
migraphx
::
starts_with
(
it
->
name
(),
"@param"
)
or
migraphx
::
starts_with
(
it
->
name
(),
"@literal"
))
migraphx
::
starts_with
(
it
->
name
(),
"@literal"
))
{
{
literal_inputs
.
push_back
(
it
);
literal_inputs
.
push_back
(
it
);
...
...
src/targets/fpga/target.cpp
View file @
5ec8f913
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/iterator_for.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -62,12 +63,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -62,12 +63,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
argument
target
::
allocate
(
const
shape
&
s
)
const
{
return
fill_argument
(
s
,
0
);
}
argument
target
::
allocate
(
const
shape
&
s
)
const
{
return
fill_argument
(
s
,
0
);
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
supported_segments
target
::
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
{
{
// for now, not using the ins and metric to return a value
(
void
)
ins
;
(
void
)
m
;
(
void
)
m
;
return
1.0
;
supported_segment
instrs
;
for
(
const
auto
ins
:
iterator_for
(
*
mod
))
{
instrs
.
instructions
.
insert
(
ins
);
}
instrs
.
metric
=
1
;
// arbitrary value
return
{
instrs
};
}
}
MIGRAPHX_REGISTER_TARGET
(
target
);
MIGRAPHX_REGISTER_TARGET
(
target
);
...
...
src/targets/gpu/code_object_op.cpp
View file @
5ec8f913
...
@@ -51,7 +51,8 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
...
@@ -51,7 +51,8 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
std
::
vector
<
void
*>
kargs
(
args
.
size
());
std
::
vector
<
void
*>
kargs
(
args
.
size
());
std
::
transform
(
std
::
transform
(
args
.
begin
(),
args
.
end
(),
kargs
.
begin
(),
[](
const
argument
&
a
)
{
return
a
.
data
();
});
args
.
begin
(),
args
.
end
(),
kargs
.
begin
(),
[](
const
argument
&
a
)
{
return
a
.
data
();
});
k
.
launch
(
ctx
.
get_stream
().
get
(),
global
,
local
,
std
::
move
(
kargs
));
auto
[
start
,
stop
]
=
ctx
.
get_perf_events
();
k
.
launch
(
ctx
.
get_stream
().
get
(),
global
,
local
,
std
::
move
(
kargs
),
start
,
stop
);
return
args
[
get_output_arg
(
args
.
size
())];
return
args
[
get_output_arg
(
args
.
size
())];
}
}
void
code_object_op
::
finalize
(
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
void
code_object_op
::
finalize
(
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
...
...
src/targets/gpu/compile_gen.cpp
View file @
5ec8f913
...
@@ -25,6 +25,13 @@
...
@@ -25,6 +25,13 @@
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -54,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
...
@@ -54,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
if
(
not
contains
({
0
,
1
},
stride
)
)
return
1
;
return
1
;
if
(
len
==
1
and
input
.
elements
()
>
sizes
.
front
())
if
(
len
==
1
and
input
.
elements
()
>
sizes
.
front
())
return
sizes
.
front
();
return
sizes
.
front
();
auto
it
=
std
::
find_if
(
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
vsize
)
{
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
// The len is divisible by the size and all the strides are divisible by
// the size
return
(
len
%
vsize
)
==
0
and
std
::
all_of
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
[
&
](
auto
i
)
{
return
contains
({
0
,
1
},
i
)
or
i
%
vsize
==
0
;
});
});
if
(
it
!=
sizes
.
end
())
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
*
it
;
return
1
;
return
1
;
...
@@ -75,25 +89,25 @@ std::string vectorize::str() const
...
@@ -75,25 +89,25 @@ std::string vectorize::str() const
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
const
std
::
size_t
max_lds_bytes
=
4096
;
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
vector
<
bool
>
result
(
inputs
.
size
())
;
std
::
transform
(
inputs
.
begin
(),
std
::
vector
<
std
::
size_t
>
preloaded
;
inputs
.
end
(),
auto
idxs
=
range
(
inputs
.
size
());
std
::
back_inserter
(
re
sult
),
std
::
copy_if
(
idxs
.
begin
(),
idxs
.
end
(),
std
::
back_inserter
(
p
re
loaded
),
[
&
](
auto
i
)
{
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
return
input
s
[
i
]
.
strides
()[
axis
]
==
0
;
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
});
inputs
.
end
(),
std
::
sort
(
preloaded
.
begin
(),
preloaded
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
i
)
{
re
sult
.
begin
()
,
re
turn
inputs
[
i
].
bytes
()
;
std
::
size_t
{
0
},
}));
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
std
::
size_t
bytes
=
0
;
if
(
b
)
for
(
auto
i
:
preloaded
)
return
s
.
bytes
();
{
return
0
;
auto
input
=
inputs
[
i
]
;
}
);
bytes
+=
input
.
bytes
(
);
if
(
bytes
<
max_lds_bytes
)
if
(
bytes
>
max_lds_bytes
)
return
{
result
}
;
break
;
// TODO: Try to partially preload items
result
[
i
]
=
true
;
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
}
return
{
result
};
return
{
result
};
}
}
...
@@ -125,6 +139,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
...
@@ -125,6 +139,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return
join_strings
(
std
::
move
(
transformers
),
", "
);
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
)
{
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
g
.
create_function
(
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}).
set_generic_types
(
m
).
set_name
(
name
));
return
g
.
str
();
}
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
std
::
string
generate_name_from_ops
(
const
module
&
m
)
{
auto
op_names
=
get_op_names
(
m
);
return
join_strings
(
op_names
,
"_"
);
}
}
// namespace gen
}
// namespace gen
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
View file @
5ec8f913
...
@@ -131,7 +131,7 @@ struct hip_array
...
@@ -131,7 +131,7 @@ struct hip_array
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
!=
(
const
hip_array
&
x
,
const
hip_array
&
y
)
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
!=
(
const
hip_array
&
x
,
const
hip_array
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
// This uses the product order rather than lexical order
// This uses the product order rather than lexical order
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
<
(
const
hip_array
&
x
,
const
hip_array
&
y
)
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
<
(
const
hip_array
&
x
,
const
hip_array
&
y
)
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
5ec8f913
...
@@ -117,12 +117,13 @@ template <class V, class F, class... Ts>
...
@@ -117,12 +117,13 @@ template <class V, class F, class... Ts>
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
migraphx
::
shape
::
type_t
>
types
=
{
get_shape
(
xs
).
type
()...};
std
::
initializer_list
<
migraphx
::
shape
::
type_t
>
types
=
{
get_shape
(
xs
).
type
()...};
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
...
@@ -134,7 +135,8 @@ void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -134,7 +135,8 @@ void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
{
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
}
}
...
...
src/targets/gpu/device/multinomial.cpp
View file @
5ec8f913
...
@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
...
@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
it
=
first
;
it
=
first
;
step
=
count
/
2
;
step
=
count
/
2
;
std
::
advance
(
it
,
step
);
std
::
advance
(
it
,
step
);
if
(
!
(
value
<
*
it
))
if
(
not
(
value
<
*
it
))
{
{
first
=
++
it
;
first
=
++
it
;
count
-=
step
+
1
;
count
-=
step
+
1
;
...
...
src/targets/gpu/driver/compile_op.cpp
View file @
5ec8f913
...
@@ -38,8 +38,11 @@ struct compile_op : action<compile_op>
...
@@ -38,8 +38,11 @@ struct compile_op : action<compile_op>
context
ctx
;
context
ctx
;
auto
inputs
=
p
.
parse_shapes
(
v
.
at
(
"inputs"
));
auto
inputs
=
p
.
parse_shapes
(
v
.
at
(
"inputs"
));
auto
op
=
gpu
::
compile_op
(
v
.
at
(
"name"
).
to
<
std
::
string
>
(),
ctx
,
inputs
,
v
);
auto
op
=
gpu
::
compile_op
(
v
.
at
(
"name"
).
to
<
std
::
string
>
(),
ctx
,
inputs
,
v
);
double
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
auto
[
host_time
,
device_time
]
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
std
::
cout
<<
op
<<
": "
<<
host_time
<<
"ms"
;
if
(
device_time
>
0
)
std
::
cout
<<
", "
<<
device_time
<<
"ms"
;
std
::
cout
<<
std
::
endl
;
}
}
};
};
...
...
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp
View file @
5ec8f913
...
@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
driver
{
namespace
driver
{
double
time_op
(
context
&
ctx
,
operation
op
,
const
std
::
vector
<
shape
>&
inputs
,
int
n
=
100
);
std
::
pair
<
double
,
double
>
time_op
(
context
&
ictx
,
operation
op
,
const
std
::
vector
<
shape
>&
inputs
,
int
n
=
100
);
}
// namespace driver
}
// namespace driver
}
// namespace gpu
}
// namespace gpu
...
...
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment