Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4b7a267a
"vscode:/vscode.git/clone" did not exist on "201c8182bfe09287a1f2849bd4f76a4f682e78b8"
Commit
4b7a267a
authored
Apr 08, 2019
by
Paul
Browse files
Merge from develop
parents
92803edf
af00eea8
Changes
124
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
576 additions
and
1490 deletions
+576
-1490
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+33
-0
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+79
-0
src/include/migraphx/op/sub.hpp
src/include/migraphx/op/sub.hpp
+32
-0
src/include/migraphx/op/tan.hpp
src/include/migraphx/op/tan.hpp
+29
-0
src/include/migraphx/op/tanh.hpp
src/include/migraphx/op/tanh.hpp
+29
-0
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+67
-0
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+32
-0
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+62
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+57
-1437
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+108
-41
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+1
-0
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+0
-1
src/program.cpp
src/program.cpp
+1
-1
src/schedule.cpp
src/schedule.cpp
+1
-1
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+1
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-1
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+9
-3
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+33
-2
src/targets/gpu/eliminate_workspace.cpp
src/targets/gpu/eliminate_workspace.cpp
+0
-1
No files found.
src/include/migraphx/op/softmax.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
softmax
{
std
::
string
name
()
const
{
return
"softmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
only_dims
(
4
);
return
inputs
.
at
(
0
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/squeeze.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
squeeze
{
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
lens
()[
axis
]
!=
1
;
}))
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
if
(
axes
.
empty
())
{
std
::
copy_if
(
old_lens
.
begin
(),
old_lens
.
end
(),
std
::
back_inserter
(
new_lens
),
[](
auto
len
)
{
return
len
!=
1
;
});
}
else
{
for
(
std
::
size_t
i
=
0
;
i
<
old_lens
.
size
();
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
new_lens
.
push_back
(
old_lens
[
i
]);
}
}
}
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
else
{
return
shape
{
type
,
new_lens
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/sub.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUB_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
sub
:
binary
<
sub
>
{
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
x
-
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/tan.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_TAN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
tan
:
unary
{
std
::
string
name
()
const
{
return
"tan"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/tanh.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_TANH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
tanh
:
unary
{
std
::
string
name
()
const
{
return
"tanh"
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/transpose.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
transpose
{
std
::
vector
<
int64_t
>
dims
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
dims
,
"dims"
));
}
std
::
string
name
()
const
{
return
"transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input_lens
.
size
())
{
MIGRAPHX_THROW
(
"Permutation has wrong number of axes"
);
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"Invalid permutation"
);
}
std
::
vector
<
size_t
>
output_lens
(
input_lens
.
size
());
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
for
(
std
::
size_t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
}
return
{
t
,
output_lens
,
output_strides
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/unary.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
unary
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/unsqueeze.hpp
0 → 100644
View file @
4b7a267a
#ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
unsqueeze
{
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
size_t
p
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
new_size
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
new_lens
[
i
]
=
1
;
}
else
{
new_lens
[
i
]
=
old_lens
[
p
++
];
}
}
return
shape
{
type
,
new_lens
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
4b7a267a
This diff is collapsed.
Click to expand it.
src/include/migraphx/rewrite_rnn.hpp
View file @
4b7a267a
...
...
@@ -4,7 +4,7 @@
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operat
ors
.hpp>
#include <migraphx/operat
ion
.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
...
...
src/onnx/onnx.cpp
View file @
4b7a267a
...
...
@@ -36,7 +36,6 @@ struct onnx_parser
onnx_parser
()
{
add_generic_op
(
"MatMul"
,
op
::
dot
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Sigmoid"
,
op
::
sigmoid
{});
add_generic_op
(
"Abs"
,
op
::
abs
{});
...
...
@@ -77,6 +76,7 @@ struct onnx_parser
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"MatMul"
,
&
onnx_parser
::
parse_matmul
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_logsoftmax
);
...
...
@@ -154,42 +154,48 @@ struct onnx_parser
});
}
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if
(
s0
.
size
()
>
s1
.
size
())
{
s0
.
swap
(
s1
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
template
<
class
T
>
instruction_ref
add_broadcastable_binary_op
(
instruction_ref
arg0
,
instruction_ref
arg1
,
T
x
)
{
if
(
arg0
->
get_shape
().
lens
()
!=
arg1
->
get_shape
().
lens
())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const
std
::
vector
<
std
::
size_t
>*
s0
=
&
arg0
->
get_shape
().
lens
();
const
std
::
vector
<
std
::
size_t
>*
s1
=
&
arg1
->
get_shape
().
lens
();
// Make sure s0 is the smaller size
if
(
s0
->
size
()
>
s1
->
size
())
std
::
swap
(
s0
,
s1
);
std
::
vector
<
std
::
size_t
>
output_lens
(
*
s1
);
auto
offset
=
s1
->
size
()
-
s0
->
size
();
std
::
transform
(
s0
->
begin
(),
s0
->
end
(),
s1
->
begin
()
+
offset
,
output_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg1
);
auto
s0
=
arg0
->
get_shape
().
lens
();
auto
s1
=
arg1
->
get_shape
().
lens
();
auto
out_lens
=
compute_broadcasted_lens
(
s0
,
s1
);
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
}
else
...
...
@@ -495,25 +501,86 @@ struct onnx_parser
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
if
(
args
.
size
()
==
3
)
{
if
(
beta
!=
0.
f
)
if
(
beta
!=
0.
f
&&
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
auto
l3
=
prog
.
add_instruction
(
op
::
dot
{
alpha
},
l1
,
l2
);
a
ut
o
l4
=
args
[
2
]
;
if
(
l4
->
get_shape
().
scalar
())
// ignore args[2] (no C value added to alpha*A*B)
return
l3
;
if
(
beta
!=
1.
f
)
auto
out_lens
=
l1
->
get_shape
().
lens
(
);
o
ut
_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
()
;
auto
l3
=
args
[
2
];
auto
l3_lens
=
l3
->
get_shape
().
lens
()
;
if
(
!
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
())
)
{
auto
beta_val
=
prog
.
add_literal
(
beta
);
auto
l5
=
prog
.
add_instruction
(
op
::
scalar
{
args
[
2
]
->
get_shape
()},
beta_val
);
l4
=
prog
.
add_instruction
(
op
::
mul
{},
args
[
2
],
l5
);
l3
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
2
]);
}
return
add_broadcastable_binary_op
(
l3
,
l
4
,
op
::
add
{}
);
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l
2
,
l3
);
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
}
instruction_ref
parse_matmul
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
l0
=
args
[
0
];
auto
l1
=
args
[
1
];
auto
l0_lens
=
l0
->
get_shape
().
lens
();
auto
l1_lens
=
l1
->
get_shape
().
lens
();
// args[0] is a vector, prepend 1 to the shape
bool
is_a_prepended
=
false
;
if
(
l0_lens
.
size
()
==
1
)
{
is_a_prepended
=
true
;
l0_lens
.
insert
(
l0_lens
.
begin
(),
1
);
l0
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
0
}},
args
[
0
]);
}
bool
is_b_appended
=
false
;
if
(
l1_lens
.
size
()
==
1
)
{
is_b_appended
=
true
;
l1_lens
.
push_back
(
1
);
l1
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
1
}},
args
[
1
]);
}
instruction_ref
bl0
=
l0
;
instruction_ref
bl1
=
l1
;
if
(
!
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
l1_lens
.
end
());
if
(
l0_lens
!=
l0_broadcasted_lens
)
{
bl0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l0_broadcasted_lens
},
l0
);
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
{
bl1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l1_broadcasted_lens
},
l1
);
}
}
auto
dot_res
=
prog
.
add_instruction
(
op
::
dot
{
1.0
f
,
0.0
f
},
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
if
(
is_a_prepended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
2
}},
dot_res
);
--
num_axis
;
}
if
(
is_b_appended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
1
}},
dot_res
);
}
return
dot_res
;
}
instruction_ref
parse_batchnorm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
src/opt/memory_coloring_impl.cpp
View file @
4b7a267a
#include <migraphx/op/load.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
...
...
src/opt/memory_coloring_impl.hpp
View file @
4b7a267a
...
...
@@ -3,7 +3,6 @@
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/config.hpp>
...
...
src/program.cpp
View file @
4b7a267a
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/identity
.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
...
...
src/schedule.cpp
View file @
4b7a267a
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/identity
.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/functional.hpp>
...
...
src/simplify_algebra.cpp
View file @
4b7a267a
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/add
.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
...
...
src/simplify_reshapes.cpp
View file @
4b7a267a
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/as_shape
.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
...
...
src/targets/cpu/gemm.cpp
View file @
4b7a267a
...
...
@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
auto
c
=
make_mat
(
cmat
);
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
c
=
beta
*
c
;
// This is a simple optimization to avoid
// compute A * B if alpha is 0.0
if
(
alpha
!=
0.0
)
{
c
=
c
+
alpha
*
a
*
b
;
}
});
});
}
...
...
@@ -95,8 +101,8 @@ void migemm_impl(
{
auto
lens
=
amat
.
get_shape
().
lens
();
bool
batch_mul
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
(
*
lens
.
rbegin
()
)
*
(
*
(
lens
.
rbegin
()
+
1
))
;
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
1
;
if
(
batch_mul
)
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
...
...
src/targets/cpu/lowering.cpp
View file @
4b7a267a
...
...
@@ -369,12 +369,43 @@ struct cpu_gemm
{
op
::
dot
op
;
std
::
string
name
()
const
{
return
"cpu::dot"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
if
(
inputs
.
size
()
==
3
)
{
auto
c_shape
=
inputs
.
at
(
2
);
check_shapes
{{
c_shape
}}.
not_broadcasted
();
}
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B
if
(
args
.
size
()
==
3
)
{
// no need to consider the value of args[2]
if
(
op
.
beta
==
0.0
f
)
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
}
else
{
visit_all
(
result
,
args
[
2
])([
&
](
auto
output
,
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
}
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
return
result
;
}
// 2 input arguments
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
0.0
f
);
return
result
;
}
};
...
...
src/targets/gpu/eliminate_workspace.cpp
View file @
4b7a267a
...
...
@@ -2,7 +2,6 @@
#include <migraphx/gpu/hip.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment