Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bc5d7f75
Commit
bc5d7f75
authored
Feb 15, 2019
by
Paul
Browse files
Merge from develop
parents
47c0854d
a5b0afa0
Changes
337
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1146 additions
and
233 deletions
+1146
-233
src/onnx/perf_onnx.cpp
src/onnx/perf_onnx.cpp
+11
-11
src/onnx/read_onnx.cpp
src/onnx/read_onnx.cpp
+2
-2
src/onnx/verify_onnx.cpp
src/onnx/verify_onnx.cpp
+30
-29
src/opt/common_header.hpp
src/opt/common_header.hpp
+19
-19
src/opt/memory_coloring.cpp
src/opt/memory_coloring.cpp
+6
-6
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+17
-18
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+11
-11
src/program.cpp
src/program.cpp
+59
-25
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+20
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+153
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+668
-0
src/shape.cpp
src/shape.cpp
+10
-10
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+9
-9
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+80
-37
src/targets/cpu/CMakeLists.txt
src/targets/cpu/CMakeLists.txt
+7
-7
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+7
-7
src/targets/cpu/include/migraph/cpu/context.hpp
src/targets/cpu/include/migraph/cpu/context.hpp
+0
-19
src/targets/cpu/include/migraph/cpu/target.hpp
src/targets/cpu/include/migraph/cpu/target.hpp
+0
-23
src/targets/cpu/include/migraphx/cpu/context.hpp
src/targets/cpu/include/migraphx/cpu/context.hpp
+19
-0
src/targets/cpu/include/migraphx/cpu/gemm.hpp
src/targets/cpu/include/migraphx/cpu/gemm.hpp
+18
-0
No files found.
src/onnx/perf_onnx.cpp
View file @
bc5d7f75
#include <migraph/onnx.hpp>
#include <migraph
x
/onnx.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraph/verify.hpp>
#include <migraph
x
/gpu/target.hpp>
#include <migraph
x
/gpu/hip.hpp>
#include <migraph
x
/generate.hpp>
#include <migraph
x
/verify.hpp>
migraph
::
program
::
parameter_map
create_param_map
(
const
migraph
::
program
&
p
,
bool
gpu
=
true
)
migraph
x
::
program
::
parameter_map
create_param_map
(
const
migraph
x
::
program
&
p
,
bool
gpu
=
true
)
{
migraph
::
program
::
parameter_map
m
;
migraph
x
::
program
::
parameter_map
m
;
for
(
auto
&&
x
:
p
.
get_parameter_shapes
())
{
if
(
gpu
)
m
[
x
.
first
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
generate_argument
(
x
.
second
));
m
[
x
.
first
]
=
migraph
x
::
gpu
::
to_gpu
(
migraph
x
::
generate_argument
(
x
.
second
));
else
m
[
x
.
first
]
=
migraph
::
generate_argument
(
x
.
second
);
m
[
x
.
first
]
=
migraph
x
::
generate_argument
(
x
.
second
);
}
return
m
;
}
...
...
@@ -25,9 +25,9 @@ int main(int argc, char const* argv[])
{
std
::
string
file
=
argv
[
1
];
std
::
size_t
n
=
argc
>
2
?
std
::
stoul
(
argv
[
2
])
:
50
;
auto
p
=
migraph
::
parse_onnx
(
file
);
auto
p
=
migraph
x
::
parse_onnx
(
file
);
std
::
cout
<<
"Compiling ... "
<<
std
::
endl
;
p
.
compile
(
migraph
::
gpu
::
target
{});
p
.
compile
(
migraph
x
::
gpu
::
target
{});
std
::
cout
<<
"Allocating params ... "
<<
std
::
endl
;
auto
m
=
create_param_map
(
p
);
std
::
cout
<<
"Running performance report ... "
<<
std
::
endl
;
...
...
src/onnx/read_onnx.cpp
View file @
bc5d7f75
#include <migraph/onnx.hpp>
#include <migraph
x
/onnx.hpp>
int
main
(
int
argc
,
char
const
*
argv
[])
{
if
(
argc
>
1
)
{
std
::
string
file
=
argv
[
1
];
auto
prog
=
migraph
::
parse_onnx
(
file
);
auto
prog
=
migraph
x
::
parse_onnx
(
file
);
std
::
cout
<<
prog
<<
std
::
endl
;
}
}
src/onnx/verify_onnx.cpp
View file @
bc5d7f75
#include <migraph/onnx.hpp>
#include <migraph
x
/onnx.hpp>
#include <migraph/cpu/target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraph/verify_args.hpp>
#include <migraph/instruction.hpp>
#include <migraph
x
/cpu/target.hpp>
#include <migraph
x
/gpu/target.hpp>
#include <migraph
x
/gpu/hip.hpp>
#include <migraph
x
/generate.hpp>
#include <migraph
x
/verify_args.hpp>
#include <migraph
x
/instruction.hpp>
template
<
class
T
>
auto
get_hash
(
const
T
&
x
)
...
...
@@ -15,14 +15,14 @@ auto get_hash(const T& x)
}
template
<
class
F
>
migraph
::
argument
run_cpu
(
F
f
)
migraph
x
::
argument
run_cpu
(
F
f
)
{
auto
p
=
f
();
p
.
compile
(
migraph
::
cpu
::
target
{});
migraph
::
program
::
parameter_map
m
;
p
.
compile
(
migraph
x
::
cpu
::
target
{});
migraph
x
::
program
::
parameter_map
m
;
for
(
auto
&&
x
:
p
.
get_parameter_shapes
())
{
m
[
x
.
first
]
=
migraph
::
generate_argument
(
x
.
second
,
get_hash
(
x
.
first
));
m
[
x
.
first
]
=
migraph
x
::
generate_argument
(
x
.
second
,
get_hash
(
x
.
first
));
}
auto
out
=
p
.
eval
(
m
);
std
::
cout
<<
p
<<
std
::
endl
;
...
...
@@ -30,19 +30,20 @@ migraph::argument run_cpu(F f)
}
template
<
class
F
>
migraph
::
argument
run_gpu
(
F
f
)
migraph
x
::
argument
run_gpu
(
F
f
)
{
auto
p
=
f
();
p
.
compile
(
migraph
::
gpu
::
target
{});
p
.
compile
(
migraph
x
::
gpu
::
target
{});
migraph
::
program
::
parameter_map
m
;
migraph
x
::
program
::
parameter_map
m
;
for
(
auto
&&
x
:
p
.
get_parameter_shapes
())
{
m
[
x
.
first
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
generate_argument
(
x
.
second
,
get_hash
(
x
.
first
)));
m
[
x
.
first
]
=
migraphx
::
gpu
::
to_gpu
(
migraphx
::
generate_argument
(
x
.
second
,
get_hash
(
x
.
first
)));
}
auto
out
=
migraph
::
gpu
::
from_gpu
(
p
.
eval
(
m
));
auto
out
=
migraph
x
::
gpu
::
from_gpu
(
p
.
eval
(
m
));
std
::
cout
<<
p
<<
std
::
endl
;
return
migraph
::
gpu
::
from_gpu
(
out
);
return
migraph
x
::
gpu
::
from_gpu
(
out
);
}
template
<
class
F
>
...
...
@@ -50,12 +51,12 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
{
auto
x
=
run_cpu
(
f
);
auto
y
=
run_gpu
(
f
);
migraph
::
verify_args
(
name
,
x
,
y
,
tolerance
);
migraph
x
::
verify_args
(
name
,
x
,
y
,
tolerance
);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void
verify_instructions
(
const
migraph
::
program
&
prog
,
double
tolerance
=
80
)
void
verify_instructions
(
const
migraph
x
::
program
&
prog
,
double
tolerance
=
80
)
{
for
(
auto
&&
ins
:
prog
)
{
...
...
@@ -68,8 +69,8 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
if
(
ins
.
name
()
==
"reshape"
)
continue
;
auto
create_program
=
[
&
]
{
migraph
::
program
p
;
std
::
vector
<
migraph
::
instruction_ref
>
inputs
;
migraph
x
::
program
p
;
std
::
vector
<
migraph
x
::
instruction_ref
>
inputs
;
for
(
auto
&&
arg
:
ins
.
inputs
())
{
if
(
arg
->
name
()
==
"@literal"
)
...
...
@@ -100,8 +101,8 @@ void verify_reduced(F f, int n, double tolerance = 80)
{
auto
create_program
=
[
&
]
{
migraph
::
program
p
=
f
();
auto
last
=
std
::
prev
(
p
.
end
(),
n
+
1
);
migraph
x
::
program
p
=
f
();
auto
last
=
std
::
prev
(
p
.
end
(),
n
+
1
);
p
.
remove_instructions
(
last
,
p
.
end
());
return
p
;
};
...
...
@@ -113,9 +114,9 @@ void verify_reduced(F f, int n, double tolerance = 80)
template
<
class
F
>
void
verify_reduced_program
(
F
f
,
double
tolerance
=
80
)
{
migraph
::
program
p
=
f
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
for
(
in
t
i
=
0
;
i
<
n
;
i
++
)
migraph
x
::
program
p
=
f
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
for
(
std
::
size_
t
i
=
0
;
i
<
n
;
i
++
)
{
verify_reduced
(
f
,
i
,
tolerance
);
}
...
...
@@ -127,7 +128,7 @@ int main(int argc, char const* argv[])
if
(
not
args
.
empty
())
{
std
::
string
file
=
args
.
front
();
auto
p
=
migraph
::
parse_onnx
(
file
);
auto
p
=
migraph
x
::
parse_onnx
(
file
);
std
::
cout
<<
p
<<
std
::
endl
;
if
(
std
::
any_of
(
args
.
begin
(),
args
.
end
(),
[](
const
auto
&
s
)
{
return
s
==
"-i"
;
}))
...
...
@@ -136,11 +137,11 @@ int main(int argc, char const* argv[])
}
else
if
(
std
::
any_of
(
args
.
begin
(),
args
.
end
(),
[](
const
auto
&
s
)
{
return
s
==
"-r"
;
}))
{
verify_reduced_program
([
&
]
{
return
migraph
::
parse_onnx
(
file
);
});
verify_reduced_program
([
&
]
{
return
migraph
x
::
parse_onnx
(
file
);
});
}
else
{
verify_program
(
file
,
[
&
]
{
return
migraph
::
parse_onnx
(
file
);
});
verify_program
(
file
,
[
&
]
{
return
migraph
x
::
parse_onnx
(
file
);
});
}
}
}
src/opt/common_header.hpp
View file @
bc5d7f75
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph/program.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/pass_config.hpp>
#include <migraph/config.hpp>
#ifndef MIGRAPH
X
_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph
x
/program.hpp>
#include <migraph
x
/stringutils.hpp>
#include <migraph
x
/instruction.hpp>
#include <migraph
x
/operators.hpp>
#include <migraph
x
/iterator_for.hpp>
#include <migraph
x
/pass_config.hpp>
#include <migraph
x
/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
//#define MIGRAPH_DEBUG_OPT
//#define MIGRAPH
X
_DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s
#ifdef MIGRAPH
X
_DEBUG_OPT
#define MIGRAPH
X
_DEBUG(s) s
#else
#define MIGRAPH_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT
#define MIGRAPH
X
_DEBUG(s)
#endif // MIGRAPH
X
_DEBUG_OPT
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#endif // MIGRAPH
X
_GUARD_RTGLIB_COMMON_HEADER_HPP
src/opt/memory_coloring.cpp
View file @
bc5d7f75
#include <migraph/memory_coloring.hpp>
#include <migraph
x
/memory_coloring.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
void
memory_coloring
::
apply
(
program
&
p
)
const
{
if
(
!
enabled
(
MIGRAPH_DISABLE_MEMORY_COLORING
{}))
if
(
!
enabled
(
MIGRAPH
X
_DISABLE_MEMORY_COLORING
{}))
{
memory_coloring_impl
opt
(
&
p
,
allocation_op
,
verify
);
opt
.
run
();
}
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
src/opt/memory_coloring_impl.cpp
View file @
bc5d7f75
#include "memory_coloring_impl.hpp"
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
void
memory_coloring_impl
::
run
()
{
MIGRAPH_DEBUG
(
dump
(
"---Before memory coloring---"
));
MIGRAPH_DEBUG
(
dump_program
());
MIGRAPH
X
_DEBUG
(
dump
(
"---Before memory coloring---"
));
MIGRAPH
X
_DEBUG
(
dump_program
());
build
();
if
(
num_of_lives
!=
0
)
{
MIGRAPH_DEBUG
(
dump_intervals
());
MIGRAPH
X
_DEBUG
(
dump_intervals
());
// Coloring
while
(
!
alloc_queue
.
empty
())
{
...
...
@@ -85,7 +85,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
conflict_queue
.
pop
();
}
segment
.
offset
=
offset
;
MIGRAPH_DEBUG
(
segment
.
dump
());
MIGRAPH
X
_DEBUG
(
segment
.
dump
());
required_bytes
=
std
::
max
(
required_bytes
,
offset
+
segment
.
size
);
return
true
;
}
...
...
@@ -118,11 +118,11 @@ void memory_coloring_impl::build()
live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
get_shape
();
def_interval
->
is_literal
=
is_lit
;
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
if
(
!
is_lit
||
unify_literals
)
alloc_queue
.
push
(
def_interval
);
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
live_set
.
erase
(
range
.
vn
);
}
}
...
...
@@ -217,8 +217,8 @@ void memory_coloring_impl::rewrite()
}
}
}
MIGRAPH_DEBUG
(
dump
(
"---After rewrite---"
));
MIGRAPH_DEBUG
(
dump_program
());
MIGRAPH
X
_DEBUG
(
dump
(
"---After rewrite---"
));
MIGRAPH
X
_DEBUG
(
dump_program
());
}
void
memory_coloring_impl
::
verify
()
...
...
@@ -232,9 +232,8 @@ void memory_coloring_impl::verify()
if
(
segment
.
begin
==
invalid_offset
)
{
// TODO: This check breaks on the tests
// if(!interval.is_live_on_entry)
// MIGRAPH_THROW("interval is not live on entry");
if
(
!
interval
.
is_live_on_entry
)
MIGRAPHX_THROW
(
"interval is not live on entry"
);
continue
;
}
...
...
@@ -252,14 +251,14 @@ void memory_coloring_impl::verify()
if
(
range
->
offset
==
invalid_offset
)
continue
;
if
(
!
is_disjoin
(
*
range
,
segment
))
MIGRAPH_THROW
(
"range and segment is not disjoined"
);
MIGRAPH
X
_THROW
(
"range and segment is not disjoined"
);
}
}
}
}
}
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPH
X
_DEBUG_OPT
void
memory_coloring_impl
::
dump
(
const
std
::
string
&
str
)
{
std
::
cout
<<
str
<<
std
::
endl
;
}
...
...
@@ -333,5 +332,5 @@ void live_interval::dump()
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
src/opt/memory_coloring_impl.hpp
View file @
bc5d7f75
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp"
#include <migraph/config.hpp>
#include <migraph
x
/config.hpp>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
static
const
int
invalid_offset
=
-
1
;
...
...
@@ -15,7 +15,7 @@ struct live_range
long
long
offset
;
// offset to base pointer of allocated memory trunk.
int
vn
;
// value number that identifies this live_range.
long
long
size
;
// size of required memory in bytes
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPH
X
_DEBUG_OPT
void
dump
();
#endif
};
...
...
@@ -35,7 +35,7 @@ struct live_interval
int
get_end
()
const
{
return
segment
.
end
;
}
long
long
get_offset
()
const
{
return
segment
.
offset
;
}
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPH
X
_DEBUG_OPT
void
dump
();
#endif
...
...
@@ -84,7 +84,7 @@ struct memory_coloring_impl
{
return
is_param
(
ins
)
&&
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
==
"output"
;
}
bool
is_allocate
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
allocation_op
;
}
bool
is_allocate
(
const
instruction_ref
ins
)
const
{
return
ins
->
name
()
==
allocation_op
;
}
static
bool
is_outline
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@outline"
;
}
static
bool
is_literal
(
const
instruction_ref
ins
)
{
return
ins
->
name
()
==
"@literal"
;
}
static
bool
is_check_context
(
const
instruction_ref
ins
)
...
...
@@ -101,7 +101,7 @@ struct memory_coloring_impl
return
((
end1
<
range2
.
offset
)
||
(
end2
<
range1
.
offset
));
}
void
verify
();
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPH
X
_DEBUG_OPT
void
dump
(
const
std
::
string
&
);
void
dump_program
();
void
dump_intervals
();
...
...
@@ -154,6 +154,6 @@ struct memory_coloring_impl
bool
enable_verify
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
#endif
src/program.cpp
View file @
bc5d7f75
#include <migraph/program.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp>
#include <migraph/env.hpp>
#include <migraph/ranges.hpp>
#include <migraph/time.hpp>
#include <migraph/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
MIGRAPH_DECLARE_ENV_VAR
(
MIGRAPH_TRACE_COMPILE
)
MIGRAPH_DECLARE_ENV_VAR
(
MIGRAPH_TRACE_EVAL
)
MIGRAPH
X
_DECLARE_ENV_VAR
(
MIGRAPH
X
_TRACE_COMPILE
)
MIGRAPH
X
_DECLARE_ENV_VAR
(
MIGRAPH
X
_TRACE_EVAL
)
struct
program_impl
{
...
...
@@ -134,6 +135,12 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert
(
has_instruction
(
ins
));
assert
(
has_instruction
(
rep
));
assert
(
ins
!=
rep
);
if
(
ins
==
std
::
prev
(
this
->
end
()))
{
return
replace_instruction
(
ins
,
op
::
identity
{},
rep
);
}
// TODO: Should it be an error if the output is empty?
if
(
ins
->
outputs
().
empty
())
{
...
...
@@ -271,6 +278,8 @@ instruction_ref program::end() const { return impl->instructions.end(); }
shape
program
::
get_shape
()
const
{
return
impl
->
instructions
.
back
().
get_shape
();
}
context
&
program
::
get_context
()
const
{
return
impl
->
ctx
;
}
instruction_ref
program
::
validate
()
const
{
return
std
::
find_if
(
impl
->
instructions
.
begin
(),
...
...
@@ -282,7 +291,7 @@ void program::compile(const target& t, tracer trace)
{
assert
(
this
->
validate
()
==
impl
->
instructions
.
end
());
this
->
impl
->
ctx
=
t
.
get_context
();
if
(
enabled
(
MIGRAPH_TRACE_COMPILE
{}))
if
(
enabled
(
MIGRAPH
X
_TRACE_COMPILE
{}))
trace
=
tracer
{
std
::
cout
};
trace
(
*
this
);
trace
();
...
...
@@ -297,8 +306,8 @@ void program::compile(const target& t, tracer trace)
if
(
invalid
!=
impl
->
instructions
.
end
())
{
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
std
::
to_string
(
index
)
+
": "
+
invalid
->
name
());
MIGRAPH
X
_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
std
::
to_string
(
index
)
+
": "
+
invalid
->
name
());
}
trace
();
#endif
...
...
@@ -307,7 +316,16 @@ void program::compile(const target& t, tracer trace)
if
(
invalid
!=
impl
->
instructions
.
end
())
{
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
MIGRAPH_THROW
(
"Invalid program from compilation at instruction "
+
std
::
to_string
(
index
));
MIGRAPHX_THROW
(
"Invalid program from compilation at instruction "
+
std
::
to_string
(
index
));
}
this
->
finalize
();
}
void
program
::
finalize
()
{
for
(
auto
ins
:
iterator_for
(
*
this
))
{
ins
->
finalize
(
this
->
impl
->
ctx
);
}
}
...
...
@@ -334,7 +352,7 @@ argument generic_eval(const program& p,
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
if
(
not
contains
(
params
,
param_name
))
MIGRAPH_THROW
(
"Parameter not found: "
+
param_name
);
MIGRAPH
X
_THROW
(
"Parameter not found: "
+
param_name
);
return
params
.
at
(
param_name
);
}));
}
...
...
@@ -361,20 +379,31 @@ argument generic_eval(const program& p,
argument
program
::
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
if
(
enabled
(
MIGRAPH_TRACE_EVAL
{}))
auto
&
ctx
=
this
->
impl
->
ctx
;
#ifndef NDEBUG
auto
sctx
=
ctx
;
auto
check_context
=
[
&
](
auto
f
)
{
assert
(
is_shared
(
ctx
,
sctx
));
auto
x
=
f
();
sctx
=
ctx
;
return
x
;
};
#else
auto
check_context
=
[](
auto
f
)
{
return
f
();
};
#endif
if
(
enabled
(
MIGRAPHX_TRACE_EVAL
{}))
{
auto
&
ctx
=
this
->
impl
->
ctx
;
return
generic_eval
(
*
this
,
this
->
impl
->
ctx
,
std
::
move
(
params
),
[
&
](
auto
&
ins
,
auto
f
)
{
return
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[
&
](
auto
&
ins
,
auto
f
)
{
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
;
this
->
debug_print
(
ins
);
return
f
(
);
return
check_context
(
f
);
});
}
else
{
return
generic_eval
(
*
this
,
this
->
impl
->
ctx
,
std
::
move
(
params
),
[](
auto
&
,
auto
f
)
{
return
f
(
);
});
*
this
,
ctx
,
std
::
move
(
params
),
[
&
](
auto
&
,
auto
f
)
{
return
check_context
(
f
);
});
}
}
...
...
@@ -428,8 +457,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
overhead_vec
.
reserve
(
n
);
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
{
overhead_vec
.
push_back
(
time
<
milliseconds
>
(
[
&
]
{
generic_eval
(
*
this
,
ctx
,
params
,
[](
auto
...)
{
return
argument
{};
});
}));
overhead_vec
.
push_back
(
time
<
milliseconds
>
([
&
]
{
dry_run
(
params
);
}));
}
double
total_time
=
common_average
(
total_vec
);
...
...
@@ -493,6 +521,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
std
::
cout
<<
std
::
endl
;
}
void
program
::
dry_run
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
auto
&
ctx
=
this
->
impl
->
ctx
;
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[](
auto
&&
...)
{
return
argument
{};
});
}
bool
operator
==
(
const
program
&
x
,
const
program
&
y
)
{
return
to_string
(
x
)
==
to_string
(
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
)
...
...
@@ -501,5 +535,5 @@ std::ostream& operator<<(std::ostream& os, const program& p)
return
os
;
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
src/py/CMakeLists.txt
0 → 100644
View file @
bc5d7f75
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
find_program
(
DEFAULT_PYTHON_EXE python
)
if
(
DEFAULT_PYTHON_EXE
)
set
(
PYTHON_EXECUTABLE
${
DEFAULT_PYTHON_EXE
}
CACHE PATH
"Path to python executable"
)
endif
()
find_package
(
pybind11 REQUIRED
)
pybind11_add_module
(
migraphx_py migraphx_py.cpp
)
set_target_properties
(
migraphx_py PROPERTIES
OUTPUT_NAME migraphx
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
target_link_libraries
(
migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu
)
if
(
MIGRAPHX_ENABLE_GPU
)
target_link_libraries
(
migraphx_py PRIVATE migraphx_gpu
)
target_compile_definitions
(
migraphx_py PRIVATE -DHAVE_GPU
)
endif
()
endif
()
src/py/migraphx_py.cpp
0 → 100644
View file @
bc5d7f75
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace
py
=
pybind11
;
template
<
class
F
>
struct
throw_half
{
F
f
;
template
<
class
A
>
void
operator
()(
A
a
)
const
{
f
(
a
);
}
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
{
throw
std
::
runtime_error
(
"Half not supported in python yet."
);
}
};
template
<
class
F
>
struct
skip_half
{
F
f
;
template
<
class
A
>
void
operator
()(
A
a
)
const
{
f
(
a
);
}
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
{}
};
template
<
class
F
>
void
visit_type
(
const
migraphx
::
shape
&
s
,
F
f
)
{
s
.
visit_type
(
throw_half
<
F
>
{
f
});
}
template
<
class
F
>
void
visit_types
(
F
f
)
{
migraphx
::
shape
::
visit_types
(
skip_half
<
F
>
{
f
});
}
template
<
class
T
>
py
::
buffer_info
to_buffer_info
(
T
&
x
)
{
migraphx
::
shape
s
=
x
.
get_shape
();
py
::
buffer_info
b
;
visit_type
(
s
,
[
&
](
auto
as
)
{
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
s
.
lens
().
size
(),
s
.
lens
(),
s
.
strides
());
});
return
b
;
}
migraphx
::
shape
to_shape
(
const
py
::
buffer_info
&
info
)
{
migraphx
::
shape
::
type_t
t
;
visit_types
([
&
](
auto
as
)
{
if
(
info
.
format
==
py
::
format_descriptor
<
decltype
(
as
())
>::
format
())
t
=
as
.
type_enum
();
});
return
migraphx
::
shape
{
t
,
info
.
shape
,
info
.
strides
};
}
PYBIND11_MODULE
(
migraphx
,
m
)
{
py
::
class_
<
migraphx
::
shape
>
(
m
,
"shape"
)
.
def
(
py
::
init
<>
())
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
.
def
(
"broadcasted"
,
&
migraphx
::
shape
::
broadcasted
)
.
def
(
"standard"
,
&
migraphx
::
shape
::
standard
)
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
)
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
"__init__"
,
[](
migraphx
::
argument
&
x
,
py
::
buffer
b
)
{
py
::
buffer_info
info
=
b
.
request
();
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
})
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
argument
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
argument
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
argument
&
x
)
{
return
migraphx
::
to_string
(
x
);
});
py
::
class_
<
migraphx
::
target
>
(
m
,
"target"
);
py
::
class_
<
migraphx
::
program
>
(
m
,
"program"
)
.
def
(
"get_parameter_shapes"
,
&
migraphx
::
program
::
get_parameter_shapes
)
.
def
(
"get_shape"
,
&
migraphx
::
program
::
get_shape
)
.
def
(
"compile"
,
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
)
{
p
.
compile
(
t
);
})
.
def
(
"run"
,
&
migraphx
::
program
::
eval
)
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
program
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
program
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
program
&
p
)
{
return
migraphx
::
to_string
(
p
);
});
m
.
def
(
"parse_onnx"
,
&
migraphx
::
parse_onnx
);
m
.
def
(
"get_target"
,
[](
const
std
::
string
&
name
)
->
migraphx
::
target
{
if
(
name
==
"cpu"
)
return
migraphx
::
cpu
::
target
{};
#ifdef HAVE_GPU
if
(
name
==
"gpu"
)
return
migraphx
::
gpu
::
target
{};
#endif
throw
std
::
runtime_error
(
"Target not found: "
+
name
);
});
m
.
def
(
"generate_argument"
,
&
migraphx
::
generate_argument
,
py
::
arg
(
"s"
),
py
::
arg
(
"seed"
)
=
0
);
#ifdef HAVE_GPU
m
.
def
(
"allocate_gpu"
,
&
migraphx
::
gpu
::
allocate_gpu
,
py
::
arg
(
"s"
),
py
::
arg
(
"host"
)
=
false
);
m
.
def
(
"to_gpu"
,
&
migraphx
::
gpu
::
to_gpu
,
py
::
arg
(
"arg"
),
py
::
arg
(
"host"
)
=
false
);
m
.
def
(
"from_gpu"
,
&
migraphx
::
gpu
::
from_gpu
);
m
.
def
(
"gpu_sync"
,
&
migraphx
::
gpu
::
gpu_sync
);
m
.
def
(
"copy_to_gpu"
,
&
migraphx
::
gpu
::
copy_to_gpu
);
#endif
#ifdef VERSION_INFO
m
.
attr
(
"__version__"
)
=
VERSION_INFO
;
#else
m
.
attr
(
"__version__"
)
=
"dev"
;
#endif
}
src/rewrite_rnn.cpp
0 → 100644
View file @
bc5d7f75
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
rewrite_rnn
::
apply
(
program
&
prog
)
const
{
for
(
auto
ins
:
iterator_for
(
prog
))
{
if
(
ins
->
name
()
==
"rnn"
)
{
apply_vanilla_rnn
(
prog
,
ins
);
}
if
(
ins
->
name
()
==
"gru"
)
{
apply_gru
(
prog
,
ins
);
}
}
}
void
rewrite_rnn
::
apply_vanilla_rnn
(
program
&
prog
,
instruction_ref
ins
)
const
{
assert
(
ins
->
name
()
==
"rnn"
);
// could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have num of arguments
// when writing their program.
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
data
(
ih_shape
.
elements
(),
0
);
auto
actv_funcs
=
vanilla_rnn_actv_funcs
(
ins
);
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
op
::
rnn_direction
dicrt
=
rnn_op
.
direction
;
instruction_ref
last_output
{};
if
(
dicrt
==
op
::
rnn_direction
::
bidirectional
)
{
// input weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
// hidden state weight matrix
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
// process bias
instruction_ref
bias_forward
=
prog
.
end
();
instruction_ref
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
}
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
5
]);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
5
]);
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret_forward
=
vanilla_rnn_cell
(
true
,
prog
,
ins
,
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
ih_forward
,
actv_funcs
.
at
(
0
));
auto
ret_reverse
=
vanilla_rnn_cell
(
false
,
prog
,
ins
,
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
,
actv_funcs
.
at
(
1
));
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
concat_output
);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if
(
ret_forward
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
else
{
bool
is_forward
=
(
dicrt
==
op
::
rnn_direction
::
forward
);
// input weight matrix
auto
w
=
args
[
1
];
// hidden state weight matrix
auto
r
=
args
[
2
];
// process bias and initial hidden state
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias
=
args
[
3
];
}
// process intial hidden state
instruction_ref
ih
;
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
ih
=
args
[
5
];
}
else
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret
=
vanilla_rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
actv_funcs
.
at
(
0
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if
(
ret
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
}
}
// search its output to find if there are rnn_last_output operator
// while loop to handle case of multiple rnn_last_output operators
auto
last_output_it
=
ins
->
outputs
().
begin
();
while
(
last_output_it
!=
ins
->
outputs
().
end
())
{
last_output_it
=
std
::
find_if
(
last_output_it
,
ins
->
outputs
().
end
(),
[](
auto
i
)
{
return
i
->
name
()
==
"rnn_last_output"
;
});
if
(
last_output_it
!=
ins
->
outputs
().
end
())
{
prog
.
replace_instruction
(
*
last_output_it
,
last_output
);
last_output_it
++
;
}
}
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
vanilla_rnn_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
w
,
instruction_ref
r
,
instruction_ref
bias
,
instruction_ref
ih
,
operation
&
actv_func
)
const
{
// squeeze and transpose w
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
tran_sw
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sw
);
// squeeze and transpose r
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
tran_sr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sr
);
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
if
(
bias
!=
prog
.
end
())
{
long
hs
=
r
->
get_shape
().
lens
()[
2
];
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
}
instruction_ref
hidden_out
=
prog
.
end
();
instruction_ref
last_out
{};
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
,
1
}},
sih
);
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
{
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
instruction_ref
ht
;
if
(
bias
!=
prog
.
end
())
{
ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_ht
,
bias
);
}
else
{
ht
=
xt_ht
;
}
// apply activation function
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
sih
=
ht
;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
,
1
}},
ht
);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
// output inserted
if
(
i
<
seq_len
-
1
)
{
if
(
is_forward
)
{
hidden_out
=
(
seq_index
==
0
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last_out
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
}
}
}
return
{
hidden_out
,
last_out
};
}
std
::
vector
<
operation
>
rewrite_rnn
::
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
{
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments
// when writing their program.
if
(
rnn_op
.
direction
==
op
::
rnn_direction
::
bidirectional
)
{
if
(
rnn_op
.
actv_funcs
.
empty
())
{
// default is tanh
return
{
op
::
tanh
{},
op
::
tanh
{}};
}
else
if
(
rnn_op
.
actv_funcs
.
size
()
==
1
)
{
return
{
rnn_op
.
actv_funcs
.
at
(
0
),
rnn_op
.
actv_funcs
.
at
(
0
)};
}
else
{
return
rnn_op
.
actv_funcs
;
}
}
else
{
if
(
rnn_op
.
actv_funcs
.
empty
())
{
// default is tanh
return
{
op
::
tanh
{}};
}
else
{
return
rnn_op
.
actv_funcs
;
}
}
}
void
rewrite_rnn
::
apply_gru
(
program
&
prog
,
instruction_ref
ins
)
const
{
assert
(
ins
->
name
()
==
"gru"
);
const
auto
actv_funcs
=
gru_actv_funcs
(
ins
);
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have num of arguments
// when writing their program.
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
data
(
ih_shape
.
elements
(),
0.0
);
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
op
::
rnn_direction
dicrt
=
gru_op
.
direction
;
instruction_ref
last_output
{};
if
(
dicrt
==
op
::
rnn_direction
::
bidirectional
)
{
// w weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
// r weight matrix
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
// bias
instruction_ref
bias_forward
=
prog
.
end
();
instruction_ref
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
}
// intial hidden state
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
5
]);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
5
]);
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret_forward
=
gru_cell
(
true
,
prog
,
ins
,
{
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
ih_forward
},
gru_op
.
linear_before_reset
,
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
gru_cell
(
false
,
prog
,
ins
,
{
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
},
gru_op
.
linear_before_reset
,
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
));
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
concat_output
);
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
if
(
ret_forward
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
else
{
bool
is_forward
=
(
dicrt
==
op
::
rnn_direction
::
forward
);
// weight matrix
auto
w
=
args
[
1
];
auto
r
=
args
[
2
];
// bias
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias
=
args
[
3
];
}
// intial hidden state
instruction_ref
ih
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
ih
=
args
[
5
];
}
else
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret
=
gru_cell
(
is_forward
,
prog
,
ins
,
{
args
[
0
],
w
,
r
,
bias
,
ih
},
gru_op
.
linear_before_reset
,
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
if
(
ret
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
}
}
// replace the corresponding rnn_last_output instruction
// with the last_output, if rnn_last_output exists
// while loop to handle case of multiple rnn_last_output operators
auto
last_output_it
=
ins
->
outputs
().
begin
();
while
(
last_output_it
!=
ins
->
outputs
().
end
())
{
last_output_it
=
std
::
find_if
(
last_output_it
,
ins
->
outputs
().
end
(),
[](
auto
i
)
{
return
i
->
name
()
==
"rnn_last_output"
;
});
if
(
last_output_it
!=
ins
->
outputs
().
end
())
{
prog
.
replace_instruction
(
*
last_output_it
,
last_output
);
last_output_it
++
;
}
}
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
gru_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
)
const
{
assert
(
inputs
.
size
()
==
5
);
auto
seq
=
inputs
.
at
(
0
);
auto
w
=
inputs
.
at
(
1
);
auto
r
=
inputs
.
at
(
2
);
auto
bias
=
inputs
.
at
(
3
);
auto
ih
=
inputs
.
at
(
4
);
instruction_ref
hidden_states
=
prog
.
end
();
instruction_ref
last_output
{};
migraphx
::
shape
seq_shape
=
seq
->
get_shape
();
migraphx
::
shape
r_shape
=
r
->
get_shape
();
long
seq_len
=
static_cast
<
long
>
(
seq_shape
.
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r_shape
.
lens
()[
2
]);
migraphx
::
shape
s
(
seq_shape
.
type
(),
{
seq_shape
.
lens
()[
1
],
r_shape
.
lens
()[
2
]});
std
::
vector
<
int
>
data
(
s
.
elements
(),
1
);
auto
l1
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
// weight matrix
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sw
);
auto
tran_wz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sw
);
auto
tran_wr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sw
);
auto
tran_wh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sr
);
auto
tran_rz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sr
);
auto
tran_rr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
tran_rh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
// initial states
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
instruction_ref
brcst_bz
{};
instruction_ref
brcst_br
{};
instruction_ref
brcst_wbh
{};
instruction_ref
brcst_rbh
{};
instruction_ref
brcst_bh
{};
if
(
bias
!=
prog
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
brcst_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
brcst_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
brcst_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
bz
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
brcst_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
br
);
auto
bh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbh
,
rbh
);
brcst_bh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
bh
);
}
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
{
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
);
auto
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
);
auto
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wz
,
ht_rz
);
if
(
bias
!=
prog
.
end
())
{
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_z
,
brcst_bz
);
}
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_z
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
);
auto
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
);
auto
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wr
,
ht_rr
);
if
(
bias
!=
prog
.
end
())
{
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_r
,
brcst_br
);
}
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_r
);
instruction_ref
xht_h
;
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
sih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
{
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_h
,
brcst_bh
);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
);
if
(
bias
!=
prog
.
end
())
{
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ht1_rh
,
brcst_rbh
);
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
{
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_h
,
brcst_wbh
);
}
}
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xht_h
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
one_minus_zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
one_minus_zt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
one_minus_zt
,
ht
);
auto
zt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
sih
);
sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
one_minus_zt_ht
,
zt_ht1
);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
,
1
}},
sih
);
if
(
i
<
seq_len
-
1
)
{
if
(
is_forward
)
{
hidden_states
=
(
seq_index
==
0
)
?
last_output
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_states
,
last_output
);
}
else
{
hidden_states
=
(
seq_index
==
seq_len
-
1
)
?
last_output
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_output
,
hidden_states
);
}
}
}
return
{
hidden_states
,
last_output
};
}
std
::
vector
<
operation
>
rewrite_rnn
::
gru_actv_funcs
(
instruction_ref
ins
)
const
{
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
// before rewrite the gru operator, need to ensure
// we have 4 actv funcs, even though a user does not
// specifiy any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions
if
(
gru_op
.
direction
==
op
::
rnn_direction
::
bidirectional
)
{
if
(
gru_op
.
actv_funcs
.
empty
())
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
sigmoid
{},
op
::
tanh
{}};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
1
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
)};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
2
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
),
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
)};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
3
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
),
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
0
)};
else
return
gru_op
.
actv_funcs
;
}
else
{
if
(
gru_op
.
actv_funcs
.
empty
())
return
{
op
::
sigmoid
{},
op
::
tanh
{}};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
1
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
)};
else
return
gru_op
.
actv_funcs
;
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/shape.cpp
View file @
bc5d7f75
#include <migraph/shape.hpp>
#include <migraph/stringutils.hpp>
#include <migraph
x
/shape.hpp>
#include <migraph
x
/stringutils.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
#include <iostream>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
shape_impl
{
...
...
@@ -169,12 +169,12 @@ std::string shape::type_string() const
{
switch
(
this
->
type
())
{
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \
#define MIGRAPH
X
_SHAPE_
GENERATE_
TYPE_STRING_CASE(x, t) \
case x: return #x;
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_TYPE_STRING_CASE
)
#undef MIGRAPH_SHAPE_TYPE_STRING_CASE
MIGRAPH
X
_SHAPE_VISIT_TYPES
(
MIGRAPH
X
_SHAPE_
GENERATE_
TYPE_STRING_CASE
)
#undef MIGRAPH
X
_SHAPE_
GENERATE_
TYPE_STRING_CASE
}
MIGRAPH_THROW
(
"Invalid type"
);
MIGRAPH
X
_THROW
(
"Invalid type"
);
}
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
...
...
@@ -191,5 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return
os
;
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
src/simplify_algebra.cpp
View file @
bc5d7f75
#include <migraph/simplify_algebra.hpp>
#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/matcher.hpp>
#include <migraph/literal.hpp>
#include <migraph
x
/simplify_algebra.hpp>
#include <migraph
x
/program.hpp>
#include <migraph
x
/operators.hpp>
#include <migraph
x
/matcher.hpp>
#include <migraph
x
/literal.hpp>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
find_add_lit_broadcast
{
...
...
@@ -61,5 +61,5 @@ struct find_add_lit_broadcast
void
simplify_algebra
::
apply
(
program
&
p
)
const
{
match
::
find_matches
(
p
,
find_add_lit_broadcast
{});
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
src/simplify_reshapes.cpp
View file @
bc5d7f75
#include <migraph/simplify_reshapes.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph
x
/simplify_reshapes.hpp>
#include <migraph
x
/program.hpp>
#include <migraph
x
/instruction.hpp>
#include <migraph
x
/operators.hpp>
#include <migraph
x
/iterator_for.hpp>
#include <migraph
x
/ranges.hpp>
#include <unordered_set>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
bool
is_reshaper
(
co
nst
std
::
string
&
name
)
bool
is_reshaper
(
i
nst
ruction_ref
ins
)
{
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
,
"transpose"
,
// "broadcast",
"contiguous"
};
// clang-format on
return
contains
(
names
,
name
);
return
contains
(
names
,
ins
->
name
());
}
bool
is_transpose_output
(
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
is_transpose_output
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
size
()
!=
1
)
return
ins
;
if
(
ins
->
inputs
().
front
()
->
name
()
==
"contiguous"
)
return
find_transpose_input
(
ins
->
inputs
().
front
());
if
(
ins
->
inputs
().
front
()
->
name
()
==
"transpose"
)
return
ins
->
inputs
().
front
();
return
ins
;
}
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
not
is_reshaper
(
ins
->
name
()))
continue
;
if
(
ins
->
outputs
().
size
()
!=
1
)
continue
;
if
(
is_reshaper
(
ins
->
outputs
().
front
()
->
name
()))
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()
->
name
()))
if
(
is_reshaper
(
ins
))
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
reshapes
.
push_back
(
reshapes
.
back
()
->
inputs
().
front
());
}
if
(
std
::
any_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
&
is_reshaper
))
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
if
(
r
.
first
!=
r
.
second
)
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
if
(
r
.
first
!=
r
.
second
)
else
if
(
ins
->
name
()
==
"transpose"
)
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
if
(
is_transpose_output
(
ins
))
continue
;
auto
x
=
ins
;
auto
t
=
ins
;
do
{
x
=
t
;
t
=
find_transpose_input
(
x
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
continue
;
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
}
// Replace all reshapes with as_shape
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
name
()
!=
"reshape"
)
continue
;
p
.
replace_instruction
(
ins
,
op
::
as_shape
{
ins
->
get_shape
()},
ins
->
inputs
());
}
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
src/targets/cpu/CMakeLists.txt
View file @
bc5d7f75
add_library
(
migraph_cpu
add_library
(
migraph
x
_cpu
target.cpp
lowering.cpp
gemm.cpp
)
set_target_properties
(
migraph_cpu PROPERTIES EXPORT_NAME cpu
)
set_target_properties
(
migraph
x
_cpu PROPERTIES EXPORT_NAME cpu
)
find_path
(
BLAZE_INCLUDE blaze/Blaze.h
)
find_package
(
Threads
)
rocm_clang_tidy_check
(
migraph_cpu
)
target_link_libraries
(
migraph_cpu migraph Threads::Threads
)
target_include_directories
(
migraph_cpu PRIVATE
${
BLAZE_INCLUDE
}
)
target_compile_definitions
(
migraph_cpu PRIVATE -DBLAZE_USE_CPP_THREADS
)
rocm_clang_tidy_check
(
migraph
x
_cpu
)
target_link_libraries
(
migraph
x
_cpu migraph
x
Threads::Threads
)
target_include_directories
(
migraph
x
_cpu PRIVATE
${
BLAZE_INCLUDE
}
)
target_compile_definitions
(
migraph
x
_cpu PRIVATE -DBLAZE_USE_CPP_THREADS
)
rocm_install_targets
(
TARGETS migraph_cpu
TARGETS migraph
x
_cpu
INCLUDE
${
CMAKE_CURRENT_SOURCE_DIR
}
/include
)
...
...
src/targets/cpu/gemm.cpp
View file @
bc5d7f75
#include <migraph/cpu/gemm.hpp>
#include <migraph/dfor.hpp>
#include <migraph/requires.hpp>
#include <migraph
x
/cpu/gemm.hpp>
#include <migraph
x
/dfor.hpp>
#include <migraph
x
/requires.hpp>
#include <blaze/math/CustomMatrix.h>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
cpu
{
template
<
class
T
>
...
...
@@ -94,5 +94,5 @@ void migemm(
}
}
// namespace cpu
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
src/targets/cpu/include/migraph/cpu/context.hpp
deleted
100644 → 0
View file @
47c0854d
#ifndef MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#include <migraph/config.hpp>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
cpu
{
struct
context
{
void
finish
()
const
{}
};
}
// namespace cpu
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
#endif
src/targets/cpu/include/migraph/cpu/target.hpp
deleted
100644 → 0
View file @
47c0854d
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#include <migraph/program.hpp>
#include <migraph/cpu/context.hpp>
#include <migraph/config.hpp>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
cpu
{
struct
target
{
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraph
::
context
&
ctx
)
const
;
migraph
::
context
get_context
()
const
{
return
context
{};
}
};
}
// namespace cpu
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
#endif
src/targets/cpu/include/migraphx/cpu/context.hpp
0 → 100644
View file @
bc5d7f75
#ifndef MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
cpu
{
struct
context
{
void
finish
()
const
{}
};
}
// namespace cpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/cpu/include/migraph/cpu/gemm.hpp
→
src/targets/cpu/include/migraph
x
/cpu/gemm.hpp
View file @
bc5d7f75
#ifndef MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraph/argument.hpp>
#include <migraph/config.hpp>
#include <migraph
x
/argument.hpp>
#include <migraph
x
/config.hpp>
namespace
migraph
{
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
migraph
x
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
cpu
{
void
migemm
(
const
argument
&
c_arg
,
const
argument
&
a_arg
,
const
argument
&
b_arg
,
float
alpha
,
float
beta
);
}
// namespace cpu
}
// namespace MIGRAPH_INLINE_NS
}
// namespace migraph
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraph
x
#endif
Prev
1
2
3
4
5
6
7
8
9
10
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment