Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4b7a267a
Commit
4b7a267a
authored
Apr 08, 2019
by
Paul
Browse files
Merge from develop
parents
92803edf
af00eea8
Changes
124
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
576 additions
and
1490 deletions
+576
-1490
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+33
-0
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+79
-0
src/include/migraphx/op/sub.hpp
src/include/migraphx/op/sub.hpp
+32
-0
src/include/migraphx/op/tan.hpp
src/include/migraphx/op/tan.hpp
+29
-0
src/include/migraphx/op/tanh.hpp
src/include/migraphx/op/tanh.hpp
+29
-0
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+67
-0
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+32
-0
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+62
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+57
-1437
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+108
-41
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+1
-0
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-1
src/program.cpp
src/program.cpp
+1
-1
src/schedule.cpp
src/schedule.cpp
+1
-1
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+1
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-1
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+9
-3
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+33
-2
src/targets/gpu/eliminate_workspace.cpp
src/targets/gpu/eliminate_workspace.cpp
+0
-1
No files found.
src/include/migraphx/op/softmax.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
softmax
{
std
::
string
name
()
const
{
return
"softmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
only_dims
(
4
);
return
inputs
.
at
(
0
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/squeeze.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
squeeze
{
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
lens
()[
axis
]
!=
1
;
}))
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
if
(
axes
.
empty
())
{
std
::
copy_if
(
old_lens
.
begin
(),
old_lens
.
end
(),
std
::
back_inserter
(
new_lens
),
[](
auto
len
)
{
return
len
!=
1
;
});
}
else
{
for
(
std
::
size_t
i
=
0
;
i
<
old_lens
.
size
();
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
new_lens
.
push_back
(
old_lens
[
i
]);
}
}
}
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
else
{
return
shape
{
type
,
new_lens
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/sub.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
sub
:
binary
<
sub
>
{
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
x
-
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/tan.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
tan
:
unary
{
std
::
string
name
()
const
{
return
"tan"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/tanh.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
tanh
:
unary
{
std
::
string
name
()
const
{
return
"tanh"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/transpose.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
transpose
{
std
::
vector
<
int64_t
>
dims
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
dims
,
"dims"
));
}
std
::
string
name
()
const
{
return
"transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input_lens
.
size
())
{
MIGRAPHX_THROW
(
"Permutation has wrong number of axes"
);
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"Invalid permutation"
);
}
std
::
vector
<
size_t
>
output_lens
(
input_lens
.
size
());
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
for
(
std
::
size_t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
}
return
{
t
,
output_lens
,
output_strides
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/unary.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
unary
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/unsqueeze.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
unsqueeze
{
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
size_t
p
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
new_size
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
new_lens
[
i
]
=
1
;
}
else
{
new_lens
[
i
]
=
old_lens
[
p
++
];
}
}
return
shape
{
type
,
new_lens
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
4b7a267a
This diff is collapsed.
Click to expand it.
src/include/migraphx/rewrite_rnn.hpp
View file @
4b7a267a
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operat
ors
.hpp>
#include <migraphx/operat
ion
.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/onnx/onnx.cpp
View file @
4b7a267a
...
@@ -36,7 +36,6 @@ struct onnx_parser
...
@@ -36,7 +36,6 @@ struct onnx_parser
onnx_parser
()
onnx_parser
()
{
{
add_generic_op
(
"MatMul"
,
op
::
dot
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Sigmoid"
,
op
::
sigmoid
{});
add_generic_op
(
"Sigmoid"
,
op
::
sigmoid
{});
add_generic_op
(
"Abs"
,
op
::
abs
{});
add_generic_op
(
"Abs"
,
op
::
abs
{});
...
@@ -77,6 +76,7 @@ struct onnx_parser
...
@@ -77,6 +76,7 @@ struct onnx_parser
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"MatMul"
,
&
onnx_parser
::
parse_matmul
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_logsoftmax
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_logsoftmax
);
...
@@ -154,42 +154,48 @@ struct onnx_parser
...
@@ -154,42 +154,48 @@ struct onnx_parser
});
});
}
}
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if
(
s0
.
size
()
>
s1
.
size
())
{
s0
.
swap
(
s1
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
template
<
class
T
>
template
<
class
T
>
instruction_ref
add_broadcastable_binary_op
(
instruction_ref
arg0
,
instruction_ref
arg1
,
T
x
)
instruction_ref
add_broadcastable_binary_op
(
instruction_ref
arg0
,
instruction_ref
arg1
,
T
x
)
{
{
if
(
arg0
->
get_shape
().
lens
()
!=
arg1
->
get_shape
().
lens
())
if
(
arg0
->
get_shape
().
lens
()
!=
arg1
->
get_shape
().
lens
())
{
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
// Get lengths for both arguments
const
std
::
vector
<
std
::
size_t
>*
s0
=
&
arg0
->
get_shape
().
lens
();
auto
s0
=
arg0
->
get_shape
().
lens
();
const
std
::
vector
<
std
::
size_t
>*
s1
=
&
arg1
->
get_shape
().
lens
();
auto
s1
=
arg1
->
get_shape
().
lens
();
auto
out_lens
=
compute_broadcasted_lens
(
s0
,
s1
);
// Make sure s0 is the smaller size
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg0
);
if
(
s0
->
size
()
>
s1
->
size
())
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg1
);
std
::
swap
(
s0
,
s1
);
std
::
vector
<
std
::
size_t
>
output_lens
(
*
s1
);
auto
offset
=
s1
->
size
()
-
s0
->
size
();
std
::
transform
(
s0
->
begin
(),
s0
->
end
(),
s1
->
begin
()
+
offset
,
output_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
}
}
else
else
...
@@ -495,25 +501,86 @@ struct onnx_parser
...
@@ -495,25 +501,86 @@ struct onnx_parser
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
if
(
beta
!=
0.
f
)
if
(
beta
!=
0.
f
&&
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
{
auto
l3
=
prog
.
add_instruction
(
op
::
dot
{
alpha
},
l1
,
l2
);
auto
out_lens
=
l1
->
get_shape
().
lens
(
);
a
ut
o
l4
=
args
[
2
]
;
o
ut
_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
()
;
if
(
l4
->
get_shape
().
scalar
())
// ignore args[2] (no C value added to alpha*A*B)
auto
l3
=
args
[
2
];
return
l3
;
auto
l3_lens
=
l3
->
get_shape
().
lens
()
;
if
(
beta
!=
1.
f
)
if
(
!
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
())
)
{
{
auto
beta_val
=
prog
.
add_literal
(
beta
);
l3
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
2
]);
auto
l5
=
prog
.
add_instruction
(
op
::
scalar
{
args
[
2
]
->
get_shape
()},
beta_val
);
l4
=
prog
.
add_instruction
(
op
::
mul
{},
args
[
2
],
l5
);
}
}
return
add_broadcastable_binary_op
(
l3
,
l
4
,
op
::
add
{}
);
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l
2
,
l3
);
}
}
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
}
}
instruction_ref
parse_matmul
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
l0
=
args
[
0
];
auto
l1
=
args
[
1
];
auto
l0_lens
=
l0
->
get_shape
().
lens
();
auto
l1_lens
=
l1
->
get_shape
().
lens
();
// args[0] is a vector, prepend 1 to the shape
bool
is_a_prepended
=
false
;
if
(
l0_lens
.
size
()
==
1
)
{
is_a_prepended
=
true
;
l0_lens
.
insert
(
l0_lens
.
begin
(),
1
);
l0
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
0
}},
args
[
0
]);
}
bool
is_b_appended
=
false
;
if
(
l1_lens
.
size
()
==
1
)
{
is_b_appended
=
true
;
l1_lens
.
push_back
(
1
);
l1
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
1
}},
args
[
1
]);
}
instruction_ref
bl0
=
l0
;
instruction_ref
bl1
=
l1
;
if
(
!
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
l1_lens
.
end
());
if
(
l0_lens
!=
l0_broadcasted_lens
)
{
bl0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l0_broadcasted_lens
},
l0
);
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
{
bl1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l1_broadcasted_lens
},
l1
);
}
}
auto
dot_res
=
prog
.
add_instruction
(
op
::
dot
{
1.0
f
,
0.0
f
},
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
if
(
is_a_prepended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
2
}},
dot_res
);
--
num_axis
;
}
if
(
is_b_appended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
1
}},
dot_res
);
}
return
dot_res
;
}
instruction_ref
instruction_ref
parse_batchnorm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_batchnorm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
...
...
src/opt/memory_coloring_impl.cpp
View file @
4b7a267a
#include <migraphx/op/load.hpp>
#include "memory_coloring_impl.hpp"
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
namespace
migraphx
{
...
...
src/opt/memory_coloring_impl.hpp
View file @
4b7a267a
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
...
src/program.cpp
View file @
4b7a267a
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/identity
.hpp>
#include <migraphx/target.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
...
src/schedule.cpp
View file @
4b7a267a
#include <migraphx/schedule.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/identity
.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
...
...
src/simplify_algebra.cpp
View file @
4b7a267a
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/add
.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
...
...
src/simplify_reshapes.cpp
View file @
4b7a267a
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/as_shape
.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
#include <unordered_set>
...
...
src/targets/cpu/gemm.cpp
View file @
4b7a267a
...
@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
auto
c
=
make_mat
(
cmat
);
auto
c
=
make_mat
(
cmat
);
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
c
=
beta
*
c
;
// This is a simple optimization to avoid
// compute A * B if alpha is 0.0
if
(
alpha
!=
0.0
)
{
c
=
c
+
alpha
*
a
*
b
;
}
});
});
});
});
}
}
...
@@ -95,8 +101,8 @@ void migemm_impl(
...
@@ -95,8 +101,8 @@ void migemm_impl(
{
{
auto
lens
=
amat
.
get_shape
().
lens
();
auto
lens
=
amat
.
get_shape
().
lens
();
bool
batch_mul
=
bool
batch_mul
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
std
::
accumulate
(
(
*
lens
.
rbegin
()
)
*
(
*
(
lens
.
rbegin
()
+
1
))
;
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
1
;
if
(
batch_mul
)
if
(
batch_mul
)
{
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
...
...
src/targets/cpu/lowering.cpp
View file @
4b7a267a
...
@@ -369,12 +369,43 @@ struct cpu_gemm
...
@@ -369,12 +369,43 @@ struct cpu_gemm
{
{
op
::
dot
op
;
op
::
dot
op
;
std
::
string
name
()
const
{
return
"cpu::dot"
;
}
std
::
string
name
()
const
{
return
"cpu::dot"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
if
(
inputs
.
size
()
==
3
)
{
auto
c_shape
=
inputs
.
at
(
2
);
check_shapes
{{
c_shape
}}.
not_broadcasted
();
}
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B
if
(
args
.
size
()
==
3
)
{
// no need to consider the value of args[2]
if
(
op
.
beta
==
0.0
f
)
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
}
else
{
visit_all
(
result
,
args
[
2
])([
&
](
auto
output
,
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
}
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
return
result
;
}
// 2 input arguments
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
0.0
f
);
return
result
;
return
result
;
}
}
};
};
...
...
src/targets/gpu/eliminate_workspace.cpp
View file @
4b7a267a
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment