Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
20b1d690
Commit
20b1d690
authored
Sep 20, 2019
by
Paul
Browse files
Merge branch 'develop' into tests
parents
17aaaa1e
ba729cfc
Changes
281
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1080 additions
and
156 deletions
+1080
-156
src/include/migraphx/array.hpp
src/include/migraphx/array.hpp
+75
-0
src/include/migraphx/eliminate_common_subexpression.hpp
src/include/migraphx/eliminate_common_subexpression.hpp
+2
-2
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+60
-2
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+12
-2
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+1
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+245
-48
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+81
-0
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+81
-0
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+16
-8
src/include/migraphx/op/capture.hpp
src/include/migraphx/op/capture.hpp
+52
-0
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+18
-45
src/include/migraphx/op/erf.hpp
src/include/migraphx/op/erf.hpp
+23
-0
src/include/migraphx/op/logsoftmax.hpp
src/include/migraphx/op/logsoftmax.hpp
+1
-8
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+17
-3
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+8
-38
src/include/migraphx/op/pow.hpp
src/include/migraphx/op/pow.hpp
+23
-0
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+79
-0
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+92
-0
src/include/migraphx/op/reduce_mean.hpp
src/include/migraphx/op/reduce_mean.hpp
+30
-0
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+164
-0
No files found.
src/include/migraphx/array.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
template
<
class
R
,
class
...
>
struct
array_type
{
using
type
=
R
;
};
template
<
class
...
Ts
>
struct
array_type
<
void
,
Ts
...
>
:
std
::
common_type
<
Ts
...
>
{
};
template
<
class
R
,
class
...
Ts
>
using
array_type_t
=
typename
array_type
<
R
,
Ts
...
>::
type
;
template
<
class
T
,
std
::
size_t
N
,
std
::
size_t
...
I
>
constexpr
std
::
array
<
std
::
remove_cv_t
<
T
>
,
N
>
to_array_impl
(
T
(
&
a
)[
N
],
seq
<
I
...
>
)
{
return
{{
a
[
I
]...}};
}
}
// namespace detail
template
<
class
Result
=
void
,
class
...
Ts
,
MIGRAPHX_REQUIRES
((
sizeof
...(
Ts
)
>
0
))
>
constexpr
std
::
array
<
detail
::
array_type_t
<
Result
,
Ts
...
>
,
sizeof
...(
Ts
)
>
make_array
(
Ts
&&
...
xs
)
{
return
{
static_cast
<
detail
::
array_type_t
<
Result
,
Ts
...
>>
(
std
::
forward
<
Ts
>
(
xs
))...};
}
constexpr
std
::
array
<
int
,
0
>
make_array
()
{
return
{};
}
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
to_array
(
T
(
&
a
)[
N
])
{
return
detail
::
to_array_impl
(
a
,
detail
::
gens
<
N
>
{});
}
namespace
detail
{
template
<
std
::
size_t
Offset
=
0
,
class
Array
,
std
::
size_t
...
I
>
constexpr
auto
rearray_impl
(
Array
a
,
seq
<
I
...
>
)
{
return
make_array
(
a
[
I
+
Offset
]...);
}
}
// namespace detail
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
pop_front
(
std
::
array
<
T
,
N
>
a
)
{
return
detail
::
rearray_impl
(
a
,
detail
::
gens
<
N
-
1
>
{});
}
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
pop_back
(
std
::
array
<
T
,
N
>
a
)
{
return
detail
::
rearray_impl
<
1
>
(
a
,
detail
::
gens
<
N
-
1
>
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/common_subexpression
_elimination
.hpp
→
src/include/migraphx/
eliminate_
common_subexpression.hpp
View file @
20b1d690
...
@@ -13,9 +13,9 @@ struct program;
...
@@ -13,9 +13,9 @@ struct program;
/**
/**
* Remove identical instructions.
* Remove identical instructions.
*/
*/
struct
common_subexpression
_elimination
struct
eliminate_
common_subexpression
{
{
std
::
string
name
()
const
{
return
"common_subexpression
_elimination
"
;
}
std
::
string
name
()
const
{
return
"
eliminate_
common_subexpression"
;
}
void
apply
(
program
&
p
)
const
;
void
apply
(
program
&
p
)
const
;
};
};
...
...
src/include/migraphx/functional.hpp
View file @
20b1d690
...
@@ -15,6 +15,12 @@ struct swallow
...
@@ -15,6 +15,12 @@ struct swallow
}
}
};
};
template
<
class
T
>
auto
tuple_size
(
const
T
&
)
{
return
typename
std
::
tuple_size
<
T
>::
type
{};
}
namespace
detail
{
namespace
detail
{
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
...
@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
...
@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
return
detail
::
sequence_c_impl
(
f
,
detail
::
gens
<
N
>
{});
return
detail
::
sequence_c_impl
(
f
,
detail
::
gens
<
N
>
{});
}
}
template
<
class
IntegerConstant
,
class
F
>
constexpr
auto
sequence
(
IntegerConstant
ic
,
F
&&
f
)
{
return
sequence_c
<
ic
>
(
f
);
}
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
{
{
...
@@ -95,9 +107,9 @@ constexpr void each_args(F)
...
@@ -95,9 +107,9 @@ constexpr void each_args(F)
}
}
template
<
class
F
,
class
T
>
template
<
class
F
,
class
T
>
auto
unpack
(
F
f
,
T
&
x
)
auto
unpack
(
F
f
,
T
&
&
x
)
{
{
return
sequence
_c
<
std
::
tuple_size
<
T
>
{}
>
(
[
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
x
)...);
});
return
sequence
(
tuple_size
(
x
),
[
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
static_cast
<
T
&&>
(
x
)
)...);
});
}
}
/// Implements a fix-point combinator
/// Implements a fix-point combinator
...
@@ -149,6 +161,52 @@ auto index_of(T& x)
...
@@ -149,6 +161,52 @@ auto index_of(T& x)
return
[
&
](
auto
&&
y
)
{
return
x
[
y
];
};
return
[
&
](
auto
&&
y
)
{
return
x
[
y
];
};
}
}
template
<
class
T
,
class
...
Ts
>
decltype
(
auto
)
front_args
(
T
&&
x
,
Ts
&&
...)
{
return
static_cast
<
T
&&>
(
x
);
}
template
<
class
...
Ts
>
decltype
(
auto
)
back_args
(
Ts
&&
...
xs
)
{
return
std
::
get
<
sizeof
...(
Ts
)
-
1
>
(
std
::
tuple
<
Ts
&&
...
>
(
static_cast
<
Ts
&&>
(
xs
)...));
}
template
<
class
T
,
class
...
Ts
>
auto
pop_front_args
(
T
&&
,
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
f
(
static_cast
<
Ts
&&>
(
xs
)...);
};
}
template
<
class
...
Ts
>
auto
pop_back_args
(
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
using
tuple_type
=
std
::
tuple
<
Ts
&&
...
>
;
auto
t
=
tuple_type
(
static_cast
<
Ts
&&>
(
xs
)...);
return
sequence_c
<
sizeof
...(
Ts
)
-
1
>
(
[
&
](
auto
...
is
)
{
return
f
(
std
::
get
<
is
>
(
static_cast
<
tuple_type
&&>
(
t
))...);
});
};
}
template
<
class
T
>
struct
always_f
{
T
x
;
template
<
class
...
Ts
>
constexpr
T
operator
()(
Ts
&&
...)
const
{
return
x
;
}
};
template
<
class
T
>
auto
always
(
T
x
)
{
return
always_f
<
T
>
{
x
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/generate.hpp
View file @
20b1d690
...
@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
...
@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
()
/
64
;
const
auto
half_max
=
max
/
2
;
const
auto
half_max
=
max
/
2
;
return
half_max
-
(
z
%
max
);
return
half_max
-
(
z
%
max
);
}
}
...
@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
...
@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
()
/
64
;
return
z
%
max
;
return
z
%
max
;
}
}
...
@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
...
@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
return
result
;
return
result
;
}
}
template
<
class
T
>
std
::
vector
<
T
>
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
=
]
{
return
value
;
});
return
result
;
}
argument
fill_argument
(
shape
s
,
unsigned
long
value
=
0
);
argument
generate_argument
(
shape
s
,
unsigned
long
seed
=
0
);
argument
generate_argument
(
shape
s
,
unsigned
long
seed
=
0
);
literal
generate_literal
(
shape
s
,
unsigned
long
seed
=
0
);
literal
generate_literal
(
shape
s
,
unsigned
long
seed
=
0
);
...
...
src/include/migraphx/literal.hpp
View file @
20b1d690
...
@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
...
@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
template
<
class
Iterator
>
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
void
fill
(
Iterator
start
,
Iterator
end
)
{
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
if
(
m_shape
.
standard
())
{
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
...
...
src/include/migraphx/matcher.hpp
View file @
20b1d690
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -20,6 +21,12 @@ struct matcher_context
...
@@ -20,6 +21,12 @@ struct matcher_context
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
instruction_ref
not_found
()
const
{
return
last
;
}
template
<
class
M
>
bool
matched
(
M
m
,
instruction_ref
ins
)
{
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
}
private:
private:
instruction_ref
last
;
instruction_ref
last
;
};
};
...
@@ -67,7 +74,7 @@ auto bind_match(M m, std::string name)
...
@@ -67,7 +74,7 @@ auto bind_match(M m, std::string name)
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
result
=
m
.
match
(
ctx
,
ins
);
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
!=
ctx
.
not_found
())
if
(
result
!=
ctx
.
not_found
())
ctx
.
instructions
.
emplace
(
name
,
ins
)
;
ctx
.
instructions
[
name
]
=
ins
;
return
result
;
return
result
;
});
});
}
}
...
@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
...
@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return
result
;
return
result
;
}
}
/// Find matches
in a
program
/// Find matches
for an instruction in the
program
template
<
class
...
Ms
>
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
void
find_matches
(
program
&
p
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
{
for
(
auto
ins
:
iterator_for
(
p
))
{
bool
match
=
false
;
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
...
@@ -223,64 +228,160 @@ void find_matches(program& p, Ms&&... ms)
...
@@ -223,64 +228,160 @@ void find_matches(program& p, Ms&&... ms)
match
=
true
;
match
=
true
;
},
},
ms
...);
ms
...);
}
/// Find matches in a program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
p
))
{
find_matches
(
p
,
ins
,
ms
...);
}
}
}
}
template
<
class
...
Ts
>
template
<
class
M
>
auto
all_of
(
Ts
...
ms
)
struct
find_skip
{
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
M
m
;
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
M
matcher
()
const
{
return
m
;
}
return
x
and
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
void
apply
(
program
&
,
const
matcher_result
&
)
const
{}
if
(
matches
)
};
return
ins
;
return
ctx
.
not_found
();
template
<
class
M
>
});
find_skip
<
M
>
make_find_skip
(
M
m
)
{
return
{
m
};
}
}
template
<
class
...
Ts
>
struct
lazy_and
auto
none_of
(
Ts
...
ms
)
{
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
and
g
();
}
};
struct
lazy_or
{
{
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
or
g
();
}
};
template
<
class
Op
,
bool
Start
,
bool
Matches
>
struct
match_fold_f
{
template
<
class
...
Ms
>
static
bool
fold_matchers
(
matcher_context
&
ctx
,
instruction_ref
ins
,
Ms
...
ms
)
{
Op
op
;
auto
matched
=
[
&
](
auto
m
)
{
return
[
=
,
&
ctx
]
{
return
ctx
.
matched
(
m
,
ins
);
};
};
return
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
always
(
x
),
matched
(
y
));
})(
Start
,
ms
...);
}
template
<
class
Pack
>
static
bool
fold_matchers_pack
(
matcher_context
&
ctx
,
instruction_ref
ins
,
Pack
p
)
{
return
p
([
&
](
auto
...
ms
)
{
return
match_fold_f
::
fold_matchers
(
ctx
,
ins
,
ms
...);
});
}
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
bool
matches
=
match_fold_f
::
fold_matchers
(
ctx
,
ins
,
ms
...);
return
x
and
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
();
if
(
matches
==
Matches
)
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ins
;
return
ctx
.
not_found
();
return
ctx
.
not_found
();
});
});
}
}
template
<
class
...
Ts
>
template
<
class
Selector
>
auto
any_of
(
Ts
...
ms
)
auto
operator
[](
Selector
select
)
const
{
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
return
[
=
](
auto
...
ms
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
// Workaround ICE on gcc by packing matchers into an object
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
auto
mpack
=
pack
(
ms
...);
})(
false
,
ms
...);
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
if
(
matches
)
Op
op
;
return
ins
;
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
auto
fm
=
[
&
]
{
return
match_fold_f
::
fold_matchers_pack
(
ctx
,
ins
,
mpack
);
};
matches
=
op
(
always
(
matches
),
fm
);
});
if
(
matches
==
Matches
)
return
start
;
return
ctx
.
not_found
();
return
ctx
.
not_found
();
});
});
};
}
};
const
constexpr
auto
all_of
=
match_fold_f
<
lazy_and
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
match_fold_f
<
lazy_or
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
match_fold_f
<
lazy_or
,
false
,
false
>
{};
template
<
class
...
Ms
>
auto
skip_matches
(
Ms
...
ms
)
{
return
make_find_skip
(
any_of
(
ms
...));
}
inline
auto
inputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
inputs
())
f
(
x
);
};
}
inline
auto
outputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
outputs
())
f
(
x
);
};
}
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPHX_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
not_standard_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
{
return
ins
->
get_shape
().
broadcasted
();
return
ins
->
get_shape
().
broadcasted
();
}
}
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
transpose_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
transposed
();
}
MIGRAPHX_PRED_MATCHER
(
same_input_shapes
,
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
empty
())
return
false
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_BASIC_MATCHER
(
output
,
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
{
if
(
ins
->
outputs
().
size
()
==
1
)
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
->
outputs
().
front
();
return
ins
->
outputs
().
front
();
return
ctx
.
not_found
();
return
ctx
.
not_found
();
}
}
MIGRAPHX_BASIC_MATCHER
(
used_once
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPHX_BASIC_MATCHER
(
used_once
,
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
{
if
(
ins
->
outputs
().
size
()
==
1
)
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
;
return
ins
;
...
@@ -289,10 +390,89 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
...
@@ -289,10 +390,89 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
return
ctx
.
not_found
();
}
}
inline
auto
name
(
std
::
string
name
)
inline
auto
used_once_recursive
(
std
::
size_t
depth
)
{
return
make_basic_fun_matcher
([
=
](
const
matcher_context
&
ctx
,
instruction_ref
start
)
{
// Used once
if
(
start
->
outputs
().
size
()
==
1
)
return
start
;
// Unused
if
(
start
->
outputs
().
empty
())
{
if
(
std
::
next
(
start
)
==
ctx
.
not_found
())
return
start
;
else
return
ctx
.
not_found
();
}
// Check for dead instructions
auto
is_dead
=
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
,
auto
n
)
{
if
(
n
==
0
)
return
false
;
if
(
ins
->
get_shape
().
elements
()
==
0
)
return
false
;
if
(
ins
->
outputs
().
empty
()
and
std
::
next
(
ins
)
!=
ctx
.
not_found
())
return
true
;
return
std
::
all_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
[
&
](
auto
i
)
{
return
self
(
i
,
n
-
1
);
});
});
auto
dead
=
std
::
count_if
(
start
->
outputs
().
begin
(),
start
->
outputs
().
end
(),
[
&
](
auto
i
)
{
return
is_dead
(
i
,
depth
);
});
if
(
dead
+
1
==
start
->
outputs
().
size
())
return
start
;
return
ctx
.
not_found
();
});
}
MIGRAPHX_PRED_MATCHER
(
is_constant
,
instruction_ref
ins
)
{
return
ins
->
can_eval
();
}
MIGRAPHX_BASIC_MATCHER
(
is_unused
,
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
empty
()
and
ins
!=
std
::
prev
(
ctx
.
not_found
()))
return
ins
;
return
ctx
.
not_found
();
}
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
{
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
instruction_ref
>
([
&
](
auto
self
,
auto
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
{
auto
next
=
ins
->
outputs
().
front
();
if
(
ctx
.
matched
(
m
,
next
))
{
auto
skipped_next
=
self
(
next
);
if
(
skipped_next
!=
ctx
.
not_found
())
return
skipped_next
;
}
return
next
;
}
return
ctx
.
not_found
();
})(
start
);
});
}
inline
auto
name
(
std
::
string
s
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
(
[
=
,
name
=
std
::
move
(
name
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
[
=
,
s
=
std
::
move
(
s
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
}
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
});
}
template
<
class
...
Ts
>
inline
auto
name
(
std
::
string
s
,
Ts
...
xs
)
// NOLINT
{
return
name
(
std
::
unordered_set
<
std
::
string
>
{
std
::
move
(
s
),
std
::
move
(
xs
)...});
}
}
inline
auto
nargs
(
std
::
size_t
n
)
inline
auto
nargs
(
std
::
size_t
n
)
...
@@ -302,7 +482,7 @@ inline auto nargs(std::size_t n)
...
@@ -302,7 +482,7 @@ inline auto nargs(std::size_t n)
inline
auto
arg
(
std
::
size_t
i
)
inline
auto
arg
(
std
::
size_t
i
)
{
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
return
make_basic_fun_matcher
([
=
](
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
i
<
ins
->
inputs
().
size
())
if
(
i
<
ins
->
inputs
().
size
())
return
ins
->
inputs
()[
i
];
return
ins
->
inputs
()[
i
];
return
ctx
.
not_found
();
return
ctx
.
not_found
();
...
@@ -338,6 +518,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
...
@@ -338,6 +518,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
};
}
}
template
<
class
M
>
auto
same_shape
(
M
m
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
i
=
m
.
match
(
ctx
,
ins
);
if
(
i
!=
ctx
.
not_found
()
and
i
->
get_shape
()
==
ins
->
get_shape
())
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
...
Ms
>
auto
same_shape
(
Ms
...
ms
)
{
return
all_of
(
same_shape
(
ms
)...);
}
}
// namespace match
}
// namespace match
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/op/argmax.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmax
{
int64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
std
::
string
name
()
const
{
return
"argmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
}
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
max_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
max_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
max_val
<
cur_val
)
{
max_val
=
cur_val
;
max_index
=
i
;
}
}
return
max_index
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/argmin.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmin
{
int64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
std
::
string
name
()
const
{
return
"argmin"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
}
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmin
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
min_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
min_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
min_val
>
cur_val
)
{
min_val
=
cur_val
;
min_index
=
i
;
}
}
return
min_index
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
std
::
size_t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmin
(
input
,
data_idx
,
batch_item_num
);
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/binary.hpp
View file @
20b1d690
...
@@ -28,23 +28,31 @@ struct binary : op_name<Derived>
...
@@ -28,23 +28,31 @@ struct binary : op_name<Derived>
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
auto
s1
=
args
[
0
].
get_shape
();
if
(
input1
.
get_shape
().
packed
()
and
input2
.
get_shape
().
packed
())
auto
s2
=
args
[
1
].
get_shape
();
if
(
s1
==
s2
and
s1
.
packed
())
{
{
shape
std_shape
{
s1
.
type
(),
s1
.
lens
()};
argument
std_result
{
std_shape
,
result
.
data
()};
argument
std_arg0
{
std_shape
,
args
[
0
].
data
()};
argument
std_arg1
{
std_shape
,
args
[
1
].
data
()};
visit_all
(
std_result
,
std_arg0
,
std_arg1
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
std
::
transform
(
input1
.
begin
(),
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
input1
.
end
(),
input2
.
begin
(),
input2
.
begin
(),
output
.
begin
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
});
}
}
else
else
{
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
input1
(
idx
.
begin
(),
idx
.
end
()),
input2
(
idx
.
begin
(),
idx
.
end
()));
input1
(
idx
.
begin
(),
idx
.
end
()),
input2
(
idx
.
begin
(),
idx
.
end
()));
});
});
}
});
});
}
return
result
;
return
result
;
}
}
...
...
src/include/migraphx/op/capture.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
capture
{
std
::
size_t
ins_index
;
std
::
function
<
void
(
std
::
size_t
ins_index
,
std
::
vector
<
argument
>
)
>
f
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
ins_index
,
"ins_index"
));
}
std
::
string
name
()
const
{
return
"capture"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
front
();
}
argument
compute
(
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
{
if
(
f
)
{
f
(
ins_index
,
args
);
}
else
{
MIGRAPHX_THROW
(
"CAPTURE: callback function is not callable!"
);
}
return
args
.
front
();
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/convolution.hpp
View file @
20b1d690
...
@@ -44,8 +44,7 @@ struct convolution
...
@@ -44,8 +44,7 @@ struct convolution
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
if
(
padding_mode
==
default_
)
{
return
{
t
,
return
{
t
,
{
{
input
.
lens
()[
0
],
input
.
lens
()[
0
],
...
@@ -64,32 +63,6 @@ struct convolution
...
@@ -64,32 +63,6 @@ struct convolution
1
)),
1
)),
}};
}};
}
}
else
if
(
padding_mode
==
same
)
{
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
2
])
/
stride
[
0
])),
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
3
])
/
stride
[
1
]))}};
}
else
if
(
padding_mode
==
valid
)
{
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
2
]
-
weights
.
lens
()[
2
]
+
1
)
/
stride
[
0
])),
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
3
]
-
weights
.
lens
()[
3
]
+
1
)
/
stride
[
1
]))}};
}
else
{
MIGRAPHX_THROW
(
"Invalid padding mode"
);
}
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/erf.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#define MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
erf
:
unary
<
erf
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
erf
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/logsoftmax.hpp
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -30,7 +23,7 @@ struct logsoftmax
...
@@ -30,7 +23,7 @@ struct logsoftmax
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
).
standard
();
check_shapes
{
inputs
}.
has
(
1
).
standard
();
if
(
axis
<
0
||
axis
>
inputs
[
0
].
lens
().
size
())
if
(
axis
<
0
||
axis
>
=
inputs
[
0
].
lens
().
size
())
{
{
MIGRAPHX_THROW
(
"LogSoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
MIGRAPHX_THROW
(
"LogSoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
" is out of range"
);
" is out of range"
);
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
20b1d690
...
@@ -35,14 +35,28 @@ struct multibroadcast
...
@@ -35,14 +35,28 @@ struct multibroadcast
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
if
(
input
.
lens
().
empty
())
MIGRAPHX_THROW
(
"inputs dimensions should be > 0"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPHX_THROW
(
"inputs dimensions should <= output size"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
input
.
lens
()[
i
]
and
input
.
lens
()[
i
]
!=
1
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
{
{
...
...
src/include/migraphx/op/pooling.hpp
View file @
20b1d690
...
@@ -31,7 +31,7 @@ struct pooling
...
@@ -31,7 +31,7 @@ struct pooling
{
{
return
pack
(
f
(
self
.
mode
,
"mode"
),
return
pack
(
f
(
self
.
mode
,
"mode"
),
f
(
self
.
padding
,
"padding"
),
f
(
self
.
padding
,
"padding"
),
f
(
self
.
padding
,
"padding_mode"
),
f
(
self
.
padding
_mode
,
"padding_mode"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
));
f
(
self
.
lengths
,
"lengths"
));
}
}
...
@@ -48,52 +48,22 @@ struct pooling
...
@@ -48,52 +48,22 @@ struct pooling
assert
(
lengths
[
0
]
<=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
0
]
<=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
1
]
<=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
assert
(
lengths
[
1
]
<=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
if
(
padding_mode
==
default_
)
{
return
{
t
,
return
{
t
,
{
{
input
.
lens
()[
0
],
input
.
lens
()[
0
],
input
.
lens
()[
1
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
],
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
],
stride
[
0
])
+
stride
[
0
])
+
1
)),
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
],
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
],
stride
[
1
])
+
stride
[
1
])
+
1
)),
1
)),
}};
}};
}
}
else
if
(
padding_mode
==
same
)
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
ceil_divide
<
std
::
size_t
>
(
input
.
lens
()[
2
],
stride
[
0
]),
ceil_divide
<
std
::
size_t
>
(
input
.
lens
()[
3
],
stride
[
1
])}};
}
else
if
(
padding_mode
==
valid
)
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
2
]
-
lengths
[
0
],
stride
[
0
])
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
3
]
-
lengths
[
1
],
stride
[
1
])
+
1
)),
}};
}
else
{
MIGRAPHX_THROW
(
"Invalid padding mode"
);
}
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/pow.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_POW_HPP
#define MIGRAPHX_GUARD_OPERATORS_POW_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
pow
:
binary
<
pow
>
{
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
pow
(
x
,
y
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/quant_convolution.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
quant_convolution
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{{
1
,
1
}};
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
group
,
"group"
));
}
std
::
string
name
()
const
{
return
"quant_convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
auto
t
=
input
.
type
();
// all input type must be int8_type and output is float_type
if
(
t
!=
shape
::
int8_type
)
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t"
);
}
t
=
shape
::
int32_type
;
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
}};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/quant_dot.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
quant_dot
{
int32_t
alpha
=
1
;
int32_t
beta
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
as_number
(
self
.
alpha
),
"alpha"
),
f
(
as_number
(
self
.
beta
),
"beta"
));
}
std
::
string
name
()
const
{
return
"quant_dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
same_type
();
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
t
!=
shape
::
int8_type
)
{
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
{
MIGRAPHX_THROW
(
"QUANT_DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
// k be multiple of 4
if
((
a
.
lens
()[
dim_1
]
%
4
)
!=
0
)
{
MIGRAPHX_THROW
(
"QUANT_DOT: size of A {"
+
to_string_range
(
a
.
lens
())
+
"} and B {"
+
to_string_range
(
b
.
lens
())
+
"} must be multiple of 4 for int8 type"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
{
MIGRAPHX_THROW
(
"QUANT_DOT: dimension mismatch, operand C: {"
+
to_string_range
(
inputs
.
at
(
2
).
lens
())
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
if
(
inputs
.
size
()
==
3
&&
inputs
.
at
(
2
).
type
()
!=
shape
::
int32_type
)
{
MIGRAPHX_THROW
(
"QUANT_DOT: operand C type must be int32"
);
}
return
{
shape
::
int32_type
,
out_lens
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reduce_mean.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#include <migraphx/op/reduce_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
reduce_mean
:
reduce_op
<
reduce_mean
>
{
reduce_mean
()
{}
reduce_mean
(
std
::
vector
<
int64_t
>
ax
)
:
reduce_op
(
std
::
move
(
ax
))
{}
auto
op
()
const
{
return
[
=
](
auto
x
,
auto
y
)
{
return
x
+
y
;
};
}
auto
output
(
const
shape
&
s
)
const
{
return
[
&
](
auto
val
)
{
return
val
/
s
.
elements
();
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reduce_op.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
lowest
{
template
<
class
T
>
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
};
struct
highest
{
template
<
class
T
>
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
max
();
}
};
struct
zero
{
template
<
class
T
>
operator
T
()
const
{
return
T
{
0
};
}
};
template
<
class
Derived
>
struct
reduce_op
:
op_name
<
Derived
>
{
std
::
vector
<
std
::
int64_t
>
axes
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
{
auto
tuned_axes
=
axes
;
if
(
tuned_axes
.
empty
())
{
tuned_axes
.
resize
(
n_dim
);
std
::
iota
(
tuned_axes
.
begin
(),
tuned_axes
.
end
(),
0
);
}
else
{
for
(
auto
&
axis
:
tuned_axes
)
{
int64_t
s_dim
=
static_cast
<
int64_t
>
(
n_dim
);
if
(
axis
>=
s_dim
or
axis
<
-
s_dim
)
{
MIGRAPHX_THROW
(
"REDUCE_OP: axis out of range"
);
}
if
(
axis
<
0
)
{
axis
+=
n_dim
;
}
}
}
return
tuned_axes
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
auto
tuned_axes
=
tune_axes
(
lens
.
size
());
for
(
auto
axis
:
tuned_axes
)
{
lens
[
axis
]
=
1
;
}
return
{
s
.
type
(),
lens
};
}
template
<
class
T
>
void
tune_dims
(
const
std
::
vector
<
int64_t
>&
tuned_axes
,
const
std
::
vector
<
T
>&
in_lens
,
std
::
vector
<
T
>&
out_lens
)
const
{
for
(
auto
axis
:
tuned_axes
)
{
out_lens
[
axis
]
=
in_lens
[
axis
];
}
}
template
<
class
T
>
void
reduce
(
tensor_view
<
T
>&
input
,
shape
&
batch_shape
,
std
::
vector
<
int64_t
>&
tuned_axes
,
std
::
vector
<
std
::
size_t
>&
out_idx
,
tensor_view
<
T
>&
output
)
const
{
auto
data_idx
=
out_idx
;
T
val
=
static_cast
<
const
Derived
&>
(
*
this
).
init
();
shape_for_each
(
batch_shape
,
[
&
](
auto
b_idx
)
{
this
->
tune_dims
(
tuned_axes
,
b_idx
,
data_idx
);
val
=
static_cast
<
const
Derived
&>
(
*
this
).
op
()(
static_cast
<
const
Derived
&>
(
*
this
).
input
()(
input
(
data_idx
.
begin
(),
data_idx
.
end
())),
val
);
});
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
output
(
batch_shape
)(
val
);
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
arg_lens
=
args
.
front
().
get_shape
().
lens
();
auto
tuned_axes
=
tune_axes
(
arg_lens
.
size
());
std
::
vector
<
std
::
size_t
>
batch_lens
(
output_shape
.
lens
().
size
(),
1
);
tune_dims
(
tuned_axes
,
arg_lens
,
batch_lens
);
shape
batch_shape
{
output_shape
.
type
(),
batch_lens
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
out_idx
=
output_shape
.
multi
(
i
);
this
->
reduce
(
input
,
batch_shape
,
tuned_axes
,
out_idx
,
output
);
});
});
return
result
;
}
auto
init
()
const
{
return
zero
();
}
auto
input
()
const
{
return
[
&
](
auto
val
)
{
return
val
;
};
}
auto
output
(
const
shape
&
)
const
{
return
[
&
](
auto
val
)
{
return
val
;
};
}
reduce_op
()
{}
reduce_op
(
std
::
vector
<
int64_t
>
ax
)
:
axes
(
std
::
move
(
ax
))
{}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Prev
1
2
3
4
5
6
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment