Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dd26f1aa
Commit
dd26f1aa
authored
Apr 08, 2019
by
Shucai Xiao
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into rnn_optimization
parents
4e3d06ab
4a3e493c
Changes
125
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
676 additions
and
1618 deletions
+676
-1618
src/include/migraphx/op/sinh.hpp
src/include/migraphx/op/sinh.hpp
+29
-0
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+96
-0
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+33
-0
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+79
-0
src/include/migraphx/op/sub.hpp
src/include/migraphx/op/sub.hpp
+29
-0
src/include/migraphx/op/tan.hpp
src/include/migraphx/op/tan.hpp
+29
-0
src/include/migraphx/op/tanh.hpp
src/include/migraphx/op/tanh.hpp
+29
-0
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+67
-0
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+32
-0
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+62
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+57
-1578
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+24
-0
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+1
-1
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+102
-35
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+1
-0
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-1
src/program.cpp
src/program.cpp
+1
-1
src/schedule.cpp
src/schedule.cpp
+3
-1
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+1
-1
No files found.
src/include/migraphx/op/sinh.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_SINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_SINH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
sinh
:
unary
{
std
::
string
name
()
const
{
return
"sinh"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/slice.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
slice
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
starts
;
std
::
vector
<
int64_t
>
ends
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
),
f
(
self
.
starts
,
"starts"
),
f
(
self
.
ends
,
"ends"
));
}
std
::
string
name
()
const
{
return
"slice"
;
}
auto
fix_index
(
const
std
::
vector
<
std
::
size_t
>&
lens
,
std
::
size_t
axis
,
int64_t
index
)
const
{
int64_t
r
=
std
::
min
(
index
,
static_cast
<
int64_t
>
(
lens
[
axis
]));
if
(
r
<
0
)
r
+=
lens
[
axis
];
return
std
::
size_t
(
r
);
}
auto
compute_offset
(
const
shape
&
s
)
const
{
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
if
(
!
axes
.
empty
())
{
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
offset
+=
fix_index
(
lens
,
axis
,
starts
[
i
])
*
strides
[
axis
];
}
}
else
{
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
{
offset
+=
fix_index
(
lens
,
axis
,
starts
[
axis
])
*
strides
[
axis
];
}
}
return
offset
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
const
auto
&
old_lens
=
input_shape
.
lens
();
const
auto
&
old_strides
=
input_shape
.
strides
();
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
{
MIGRAPHX_THROW
(
"inconsistent sizes"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
new_lens
[
axis
]
=
fix_index
(
old_lens
,
axis
,
ends
[
i
])
-
fix_index
(
old_lens
,
axis
,
starts
[
i
]);
}
return
shape
{
t
,
new_lens
,
old_strides
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
return
{
std
::
move
(
output_shape
),
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/softmax.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
softmax
{
std
::
string
name
()
const
{
return
"softmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
only_dims
(
4
);
return
inputs
.
at
(
0
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/squeeze.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
squeeze
{
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
lens
()[
axis
]
!=
1
;
}))
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
if
(
axes
.
empty
())
{
std
::
copy_if
(
old_lens
.
begin
(),
old_lens
.
end
(),
std
::
back_inserter
(
new_lens
),
[](
auto
len
)
{
return
len
!=
1
;
});
}
else
{
for
(
std
::
size_t
i
=
0
;
i
<
old_lens
.
size
();
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
new_lens
.
push_back
(
old_lens
[
i
]);
}
}
}
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
else
{
return
shape
{
type
,
new_lens
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/sub.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
sub
:
binary
{
std
::
string
name
()
const
{
return
"sub"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/tan.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
tan
:
unary
{
std
::
string
name
()
const
{
return
"tan"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/tanh.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
tanh
:
unary
{
std
::
string
name
()
const
{
return
"tanh"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/transpose.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
transpose
{
std
::
vector
<
int64_t
>
dims
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
dims
,
"dims"
));
}
std
::
string
name
()
const
{
return
"transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input_lens
.
size
())
{
MIGRAPHX_THROW
(
"Permutation has wrong number of axes"
);
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"Invalid permutation"
);
}
std
::
vector
<
size_t
>
output_lens
(
input_lens
.
size
());
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
for
(
std
::
size_t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
}
return
{
t
,
output_lens
,
output_strides
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/unary.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
unary
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/unsqueeze.hpp
0 → 100644
View file @
dd26f1aa
#ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
unsqueeze
{
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
size_t
p
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
new_size
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
new_lens
[
i
]
=
1
;
}
else
{
new_lens
[
i
]
=
old_lens
[
p
++
];
}
}
return
shape
{
type
,
new_lens
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
dd26f1aa
This diff is collapsed.
Click to expand it.
src/include/migraphx/ranges.hpp
View file @
dd26f1aa
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
}
template
<
class
C
,
class
Predicate
>
bool
any_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
any_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
C
,
class
Predicate
>
bool
none_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
none_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
void
copy
(
Range
&&
r
,
Iterator
it
)
{
{
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
dd26f1aa
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operat
ors
.hpp>
#include <migraphx/operat
ion
.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/include/migraphx/schedule.hpp
View file @
dd26f1aa
...
@@ -17,6 +17,7 @@ struct program;
...
@@ -17,6 +17,7 @@ struct program;
struct
schedule
struct
schedule
{
{
schedule_model
model
{};
schedule_model
model
{};
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
program
&
p
)
const
;
void
apply
(
program
&
p
)
const
;
};
};
...
...
src/onnx/onnx.cpp
View file @
dd26f1aa
...
@@ -36,7 +36,6 @@ struct onnx_parser
...
@@ -36,7 +36,6 @@ struct onnx_parser
onnx_parser
()
onnx_parser
()
{
{
add_generic_op
(
"MatMul"
,
op
::
dot
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Sigmoid"
,
op
::
sigmoid
{});
add_generic_op
(
"Sigmoid"
,
op
::
sigmoid
{});
add_generic_op
(
"Abs"
,
op
::
abs
{});
add_generic_op
(
"Abs"
,
op
::
abs
{});
...
@@ -77,6 +76,7 @@ struct onnx_parser
...
@@ -77,6 +76,7 @@ struct onnx_parser
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"MatMul"
,
&
onnx_parser
::
parse_matmul
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_logsoftmax
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_logsoftmax
);
...
@@ -154,42 +154,48 @@ struct onnx_parser
...
@@ -154,42 +154,48 @@ struct onnx_parser
});
});
}
}
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if
(
s0
.
size
()
>
s1
.
size
())
{
s0
.
swap
(
s1
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
template
<
class
T
>
template
<
class
T
>
instruction_ref
add_broadcastable_binary_op
(
instruction_ref
arg0
,
instruction_ref
arg1
,
T
x
)
instruction_ref
add_broadcastable_binary_op
(
instruction_ref
arg0
,
instruction_ref
arg1
,
T
x
)
{
{
if
(
arg0
->
get_shape
().
lens
()
!=
arg1
->
get_shape
().
lens
())
if
(
arg0
->
get_shape
().
lens
()
!=
arg1
->
get_shape
().
lens
())
{
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
// Get lengths for both arguments
const
std
::
vector
<
std
::
size_t
>*
s0
=
&
arg0
->
get_shape
().
lens
();
auto
s0
=
arg0
->
get_shape
().
lens
();
const
std
::
vector
<
std
::
size_t
>*
s1
=
&
arg1
->
get_shape
().
lens
();
auto
s1
=
arg1
->
get_shape
().
lens
();
auto
out_lens
=
compute_broadcasted_lens
(
s0
,
s1
);
// Make sure s0 is the smaller size
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg0
);
if
(
s0
->
size
()
>
s1
->
size
())
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg1
);
std
::
swap
(
s0
,
s1
);
std
::
vector
<
std
::
size_t
>
output_lens
(
*
s1
);
auto
offset
=
s1
->
size
()
-
s0
->
size
();
std
::
transform
(
s0
->
begin
(),
s0
->
end
(),
s1
->
begin
()
+
offset
,
output_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
}
}
else
else
...
@@ -500,19 +506,80 @@ struct onnx_parser
...
@@ -500,19 +506,80 @@ struct onnx_parser
auto
out_lens
=
l1
->
get_shape
().
lens
();
auto
out_lens
=
l1
->
get_shape
().
lens
();
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
auto
l3
=
args
[
2
];
auto
l3
=
args
[
2
];
if
(
!
std
::
equal
(
auto
l3_lens
=
l3
->
get_shape
().
lens
();
out_lens
.
begin
(),
out_lens
.
end
(),
args
[
2
]
->
get_shape
().
lens
().
begin
())
&&
if
(
!
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
()))
out_lens
.
size
()
>
2
)
{
{
l3
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
2
]);
l3
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
2
]);
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
,
l3
);
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
,
l3
);
}
}
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
},
l1
,
l2
);
return
prog
.
add_instruction
(
op
::
dot
{
alpha
},
l1
,
l2
);
}
}
instruction_ref
parse_matmul
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
l0
=
args
[
0
];
auto
l1
=
args
[
1
];
auto
l0_lens
=
l0
->
get_shape
().
lens
();
auto
l1_lens
=
l1
->
get_shape
().
lens
();
// args[0] is a vector, prepend 1 to the shape
bool
is_a_prepended
=
false
;
if
(
l0_lens
.
size
()
==
1
)
{
is_a_prepended
=
true
;
l0_lens
.
insert
(
l0_lens
.
begin
(),
1
);
l0
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
0
}},
args
[
0
]);
}
bool
is_b_appended
=
false
;
if
(
l1_lens
.
size
()
==
1
)
{
is_b_appended
=
true
;
l1_lens
.
push_back
(
1
);
l1
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
1
}},
args
[
1
]);
}
instruction_ref
bl0
=
l0
;
instruction_ref
bl1
=
l1
;
if
(
!
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
l1_lens
.
end
());
if
(
l0_lens
!=
l0_broadcasted_lens
)
{
bl0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l0_broadcasted_lens
},
l0
);
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
{
bl1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l1_broadcasted_lens
},
l1
);
}
}
auto
dot_res
=
prog
.
add_instruction
(
op
::
dot
{
1.0
f
,
0.0
f
},
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
if
(
is_a_prepended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
2
}},
dot_res
);
--
num_axis
;
}
if
(
is_b_appended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
1
}},
dot_res
);
}
return
dot_res
;
}
instruction_ref
instruction_ref
parse_batchnorm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_batchnorm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
...
...
src/opt/memory_coloring_impl.cpp
View file @
dd26f1aa
#include <migraphx/op/load.hpp>
#include "memory_coloring_impl.hpp"
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
namespace
migraphx
{
...
...
src/opt/memory_coloring_impl.hpp
View file @
dd26f1aa
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
...
src/program.cpp
View file @
dd26f1aa
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/identity
.hpp>
#include <migraphx/target.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
...
src/schedule.cpp
View file @
dd26f1aa
#include <migraphx/schedule.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/identity
.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
...
@@ -341,6 +341,8 @@ struct stream_info
...
@@ -341,6 +341,8 @@ struct stream_info
void
schedule
::
apply
(
program
&
p
)
const
void
schedule
::
apply
(
program
&
p
)
const
{
{
if
(
not
enable
)
return
;
stream_info
si
;
stream_info
si
;
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
p
.
end
());
si
.
accumulate_weights
(
last
,
model
);
si
.
accumulate_weights
(
last
,
model
);
...
...
src/simplify_algebra.cpp
View file @
dd26f1aa
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/add
.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment