Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
20b1d690
"vscode:/vscode.git/clone" did not exist on "bf336b27639952e01bfa9f6531e77da47223a274"
Commit
20b1d690
authored
Sep 20, 2019
by
Paul
Browse files
Merge branch 'develop' into tests
parents
17aaaa1e
ba729cfc
Changes
281
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1080 additions
and
156 deletions
+1080
-156
src/include/migraphx/array.hpp
src/include/migraphx/array.hpp
+75
-0
src/include/migraphx/eliminate_common_subexpression.hpp
src/include/migraphx/eliminate_common_subexpression.hpp
+2
-2
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+60
-2
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+12
-2
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+1
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+245
-48
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+81
-0
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+81
-0
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+16
-8
src/include/migraphx/op/capture.hpp
src/include/migraphx/op/capture.hpp
+52
-0
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+18
-45
src/include/migraphx/op/erf.hpp
src/include/migraphx/op/erf.hpp
+23
-0
src/include/migraphx/op/logsoftmax.hpp
src/include/migraphx/op/logsoftmax.hpp
+1
-8
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+17
-3
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+8
-38
src/include/migraphx/op/pow.hpp
src/include/migraphx/op/pow.hpp
+23
-0
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+79
-0
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+92
-0
src/include/migraphx/op/reduce_mean.hpp
src/include/migraphx/op/reduce_mean.hpp
+30
-0
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+164
-0
No files found.
src/include/migraphx/array.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
template
<
class
R
,
class
...
>
struct
array_type
{
using
type
=
R
;
};
template
<
class
...
Ts
>
struct
array_type
<
void
,
Ts
...
>
:
std
::
common_type
<
Ts
...
>
{
};
template
<
class
R
,
class
...
Ts
>
using
array_type_t
=
typename
array_type
<
R
,
Ts
...
>::
type
;
template
<
class
T
,
std
::
size_t
N
,
std
::
size_t
...
I
>
constexpr
std
::
array
<
std
::
remove_cv_t
<
T
>
,
N
>
to_array_impl
(
T
(
&
a
)[
N
],
seq
<
I
...
>
)
{
return
{{
a
[
I
]...}};
}
}
// namespace detail
template
<
class
Result
=
void
,
class
...
Ts
,
MIGRAPHX_REQUIRES
((
sizeof
...(
Ts
)
>
0
))
>
constexpr
std
::
array
<
detail
::
array_type_t
<
Result
,
Ts
...
>
,
sizeof
...(
Ts
)
>
make_array
(
Ts
&&
...
xs
)
{
return
{
static_cast
<
detail
::
array_type_t
<
Result
,
Ts
...
>>
(
std
::
forward
<
Ts
>
(
xs
))...};
}
constexpr
std
::
array
<
int
,
0
>
make_array
()
{
return
{};
}
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
to_array
(
T
(
&
a
)[
N
])
{
return
detail
::
to_array_impl
(
a
,
detail
::
gens
<
N
>
{});
}
namespace
detail
{
template
<
std
::
size_t
Offset
=
0
,
class
Array
,
std
::
size_t
...
I
>
constexpr
auto
rearray_impl
(
Array
a
,
seq
<
I
...
>
)
{
return
make_array
(
a
[
I
+
Offset
]...);
}
}
// namespace detail
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
pop_front
(
std
::
array
<
T
,
N
>
a
)
{
return
detail
::
rearray_impl
(
a
,
detail
::
gens
<
N
-
1
>
{});
}
template
<
class
T
,
std
::
size_t
N
>
constexpr
auto
pop_back
(
std
::
array
<
T
,
N
>
a
)
{
return
detail
::
rearray_impl
<
1
>
(
a
,
detail
::
gens
<
N
-
1
>
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/common_subexpression
_elimination
.hpp
→
src/include/migraphx/
eliminate_
common_subexpression.hpp
View file @
20b1d690
...
...
@@ -13,9 +13,9 @@ struct program;
/**
* Remove identical instructions.
*/
struct
common_subexpression
_elimination
struct
eliminate_
common_subexpression
{
std
::
string
name
()
const
{
return
"common_subexpression
_elimination
"
;
}
std
::
string
name
()
const
{
return
"
eliminate_
common_subexpression"
;
}
void
apply
(
program
&
p
)
const
;
};
...
...
src/include/migraphx/functional.hpp
View file @
20b1d690
...
...
@@ -15,6 +15,12 @@ struct swallow
}
};
template
<
class
T
>
auto
tuple_size
(
const
T
&
)
{
return
typename
std
::
tuple_size
<
T
>::
type
{};
}
namespace
detail
{
template
<
class
R
,
class
F
>
...
...
@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
return
detail
::
sequence_c_impl
(
f
,
detail
::
gens
<
N
>
{});
}
template
<
class
IntegerConstant
,
class
F
>
constexpr
auto
sequence
(
IntegerConstant
ic
,
F
&&
f
)
{
return
sequence_c
<
ic
>
(
f
);
}
template
<
class
F
,
class
...
Ts
>
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
{
...
...
@@ -95,9 +107,9 @@ constexpr void each_args(F)
}
template
<
class
F
,
class
T
>
auto
unpack
(
F
f
,
T
&
x
)
auto
unpack
(
F
f
,
T
&
&
x
)
{
return
sequence
_c
<
std
::
tuple_size
<
T
>
{}
>
(
[
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
x
)...);
});
return
sequence
(
tuple_size
(
x
),
[
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
static_cast
<
T
&&>
(
x
)
)...);
});
}
/// Implements a fix-point combinator
...
...
@@ -149,6 +161,52 @@ auto index_of(T& x)
return
[
&
](
auto
&&
y
)
{
return
x
[
y
];
};
}
template
<
class
T
,
class
...
Ts
>
decltype
(
auto
)
front_args
(
T
&&
x
,
Ts
&&
...)
{
return
static_cast
<
T
&&>
(
x
);
}
template
<
class
...
Ts
>
decltype
(
auto
)
back_args
(
Ts
&&
...
xs
)
{
return
std
::
get
<
sizeof
...(
Ts
)
-
1
>
(
std
::
tuple
<
Ts
&&
...
>
(
static_cast
<
Ts
&&>
(
xs
)...));
}
template
<
class
T
,
class
...
Ts
>
auto
pop_front_args
(
T
&&
,
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
f
(
static_cast
<
Ts
&&>
(
xs
)...);
};
}
template
<
class
...
Ts
>
auto
pop_back_args
(
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
using
tuple_type
=
std
::
tuple
<
Ts
&&
...
>
;
auto
t
=
tuple_type
(
static_cast
<
Ts
&&>
(
xs
)...);
return
sequence_c
<
sizeof
...(
Ts
)
-
1
>
(
[
&
](
auto
...
is
)
{
return
f
(
std
::
get
<
is
>
(
static_cast
<
tuple_type
&&>
(
t
))...);
});
};
}
template
<
class
T
>
struct
always_f
{
T
x
;
template
<
class
...
Ts
>
constexpr
T
operator
()(
Ts
&&
...)
const
{
return
x
;
}
};
template
<
class
T
>
auto
always
(
T
x
)
{
return
always_f
<
T
>
{
x
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/generate.hpp
View file @
20b1d690
...
...
@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
()
/
64
;
const
auto
half_max
=
max
/
2
;
return
half_max
-
(
z
%
max
);
}
...
...
@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
()
/
64
;
return
z
%
max
;
}
...
...
@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
return
result
;
}
template
<
class
T
>
std
::
vector
<
T
>
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
=
]
{
return
value
;
});
return
result
;
}
argument
fill_argument
(
shape
s
,
unsigned
long
value
=
0
);
argument
generate_argument
(
shape
s
,
unsigned
long
seed
=
0
);
literal
generate_literal
(
shape
s
,
unsigned
long
seed
=
0
);
...
...
src/include/migraphx/literal.hpp
View file @
20b1d690
...
...
@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
...
...
src/include/migraphx/matcher.hpp
View file @
20b1d690
...
...
@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -20,6 +21,12 @@ struct matcher_context
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
template
<
class
M
>
bool
matched
(
M
m
,
instruction_ref
ins
)
{
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
}
private:
instruction_ref
last
;
};
...
...
@@ -67,7 +74,7 @@ auto bind_match(M m, std::string name)
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
!=
ctx
.
not_found
())
ctx
.
instructions
.
emplace
(
name
,
ins
)
;
ctx
.
instructions
[
name
]
=
ins
;
return
result
;
});
}
...
...
@@ -205,82 +212,176 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return
result
;
}
/// Find matches for an instruction in the program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
if
(
match
)
return
;
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
if
(
r
.
result
==
p
.
end
())
return
;
m
.
apply
(
p
,
r
);
match
=
true
;
},
ms
...);
}
/// Find matches in a program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
p
))
{
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
if
(
match
)
return
;
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
if
(
r
.
result
==
p
.
end
())
return
;
m
.
apply
(
p
,
r
);
match
=
true
;
},
ms
...);
find_matches
(
p
,
ins
,
ms
...);
}
}
template
<
class
...
Ts
>
auto
all_of
(
Ts
...
ms
)
template
<
class
M
>
struct
find_skip
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
M
m
;
M
matcher
()
const
{
return
m
;
}
void
apply
(
program
&
,
const
matcher_result
&
)
const
{}
};
template
<
class
M
>
find_skip
<
M
>
make_find_skip
(
M
m
)
{
return
{
m
};
}
template
<
class
...
Ts
>
auto
none_of
(
Ts
...
ms
)
struct
lazy_and
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
and
g
();
}
};
struct
lazy_or
{
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
or
g
();
}
};
template
<
class
Op
,
bool
Start
,
bool
Matches
>
struct
match_fold_f
{
template
<
class
...
Ms
>
static
bool
fold_matchers
(
matcher_context
&
ctx
,
instruction_ref
ins
,
Ms
...
ms
)
{
Op
op
;
auto
matched
=
[
&
](
auto
m
)
{
return
[
=
,
&
ctx
]
{
return
ctx
.
matched
(
m
,
ins
);
};
};
return
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
always
(
x
),
matched
(
y
));
})(
Start
,
ms
...);
}
template
<
class
Pack
>
static
bool
fold_matchers_pack
(
matcher_context
&
ctx
,
instruction_ref
ins
,
Pack
p
)
{
return
p
([
&
](
auto
...
ms
)
{
return
match_fold_f
::
fold_matchers
(
ctx
,
ins
,
ms
...);
});
}
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
match_fold_f
::
fold_matchers
(
ctx
,
ins
,
ms
...);
if
(
matches
==
Matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
Selector
>
auto
operator
[](
Selector
select
)
const
{
return
[
=
](
auto
...
ms
)
{
// Workaround ICE on gcc by packing matchers into an object
auto
mpack
=
pack
(
ms
...);
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
Op
op
;
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
auto
fm
=
[
&
]
{
return
match_fold_f
::
fold_matchers_pack
(
ctx
,
ins
,
mpack
);
};
matches
=
op
(
always
(
matches
),
fm
);
});
if
(
matches
==
Matches
)
return
start
;
return
ctx
.
not_found
();
});
};
}
};
const
constexpr
auto
all_of
=
match_fold_f
<
lazy_and
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
match_fold_f
<
lazy_or
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
match_fold_f
<
lazy_or
,
false
,
false
>
{};
template
<
class
...
Ms
>
auto
skip_matches
(
Ms
...
ms
)
{
return
make_find_skip
(
any_of
(
ms
...));
}
template
<
class
...
Ts
>
auto
any_of
(
Ts
...
ms
)
inline
auto
inputs
()
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
false
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
inputs
())
f
(
x
);
};
}
inline
auto
outputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
outputs
())
f
(
x
);
};
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPHX_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
not_standard_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
broadcasted
();
}
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
transpose_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
transposed
();
}
MIGRAPHX_PRED_MATCHER
(
same_input_shapes
,
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
empty
())
return
false
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_BASIC_MATCHER
(
output
,
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
->
outputs
().
front
();
return
ctx
.
not_found
();
}
MIGRAPHX_BASIC_MATCHER
(
used_once
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPHX_BASIC_MATCHER
(
used_once
,
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
;
...
...
@@ -289,10 +390,89 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
}
inline
auto
name
(
std
::
string
name
)
inline
auto
used_once_recursive
(
std
::
size_t
depth
)
{
return
make_basic_fun_matcher
([
=
](
const
matcher_context
&
ctx
,
instruction_ref
start
)
{
// Used once
if
(
start
->
outputs
().
size
()
==
1
)
return
start
;
// Unused
if
(
start
->
outputs
().
empty
())
{
if
(
std
::
next
(
start
)
==
ctx
.
not_found
())
return
start
;
else
return
ctx
.
not_found
();
}
// Check for dead instructions
auto
is_dead
=
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
,
auto
n
)
{
if
(
n
==
0
)
return
false
;
if
(
ins
->
get_shape
().
elements
()
==
0
)
return
false
;
if
(
ins
->
outputs
().
empty
()
and
std
::
next
(
ins
)
!=
ctx
.
not_found
())
return
true
;
return
std
::
all_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
[
&
](
auto
i
)
{
return
self
(
i
,
n
-
1
);
});
});
auto
dead
=
std
::
count_if
(
start
->
outputs
().
begin
(),
start
->
outputs
().
end
(),
[
&
](
auto
i
)
{
return
is_dead
(
i
,
depth
);
});
if
(
dead
+
1
==
start
->
outputs
().
size
())
return
start
;
return
ctx
.
not_found
();
});
}
MIGRAPHX_PRED_MATCHER
(
is_constant
,
instruction_ref
ins
)
{
return
ins
->
can_eval
();
}
MIGRAPHX_BASIC_MATCHER
(
is_unused
,
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
empty
()
and
ins
!=
std
::
prev
(
ctx
.
not_found
()))
return
ins
;
return
ctx
.
not_found
();
}
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
{
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
instruction_ref
>
([
&
](
auto
self
,
auto
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
{
auto
next
=
ins
->
outputs
().
front
();
if
(
ctx
.
matched
(
m
,
next
))
{
auto
skipped_next
=
self
(
next
);
if
(
skipped_next
!=
ctx
.
not_found
())
return
skipped_next
;
}
return
next
;
}
return
ctx
.
not_found
();
})(
start
);
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
[
=
,
name
=
std
::
move
(
name
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
[
=
,
s
=
std
::
move
(
s
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
}
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
});
}
template
<
class
...
Ts
>
inline
auto
name
(
std
::
string
s
,
Ts
...
xs
)
// NOLINT
{
return
name
(
std
::
unordered_set
<
std
::
string
>
{
std
::
move
(
s
),
std
::
move
(
xs
)...});
}
inline
auto
nargs
(
std
::
size_t
n
)
...
...
@@ -302,7 +482,7 @@ inline auto nargs(std::size_t n)
inline
auto
arg
(
std
::
size_t
i
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
return
make_basic_fun_matcher
([
=
](
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
i
<
ins
->
inputs
().
size
())
return
ins
->
inputs
()[
i
];
return
ctx
.
not_found
();
...
...
@@ -338,6 +518,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
}
template
<
class
M
>
auto
same_shape
(
M
m
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
i
=
m
.
match
(
ctx
,
ins
);
if
(
i
!=
ctx
.
not_found
()
and
i
->
get_shape
()
==
ins
->
get_shape
())
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
...
Ms
>
auto
same_shape
(
Ms
...
ms
)
{
return
all_of
(
same_shape
(
ms
)...);
}
}
// namespace match
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/op/argmax.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmax
{
int64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
std
::
string
name
()
const
{
return
"argmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
}
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
max_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
max_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
max_val
<
cur_val
)
{
max_val
=
cur_val
;
max_index
=
i
;
}
}
return
max_index
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/argmin.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmin
{
int64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
std
::
string
name
()
const
{
return
"argmin"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int64_t
n_dim
=
static_cast
<
int64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
}
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
template
<
class
T
>
int64_t
calc_argmin
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
{
auto
min_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
min_index
=
0
;
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
min_val
>
cur_val
)
{
min_val
=
cur_val
;
min_index
=
i
;
}
}
return
min_index
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
std
::
size_t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
output_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmin
(
input
,
data_idx
,
batch_item_num
);
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/binary.hpp
View file @
20b1d690
...
...
@@ -28,23 +28,31 @@ struct binary : op_name<Derived>
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
if
(
input1
.
get_shape
().
packed
()
and
input2
.
get_shape
().
packed
())
{
auto
s1
=
args
[
0
].
get_shape
();
auto
s2
=
args
[
1
].
get_shape
();
if
(
s1
==
s2
and
s1
.
packed
())
{
shape
std_shape
{
s1
.
type
(),
s1
.
lens
()};
argument
std_result
{
std_shape
,
result
.
data
()};
argument
std_arg0
{
std_shape
,
args
[
0
].
data
()};
argument
std_arg1
{
std_shape
,
args
[
1
].
data
()};
visit_all
(
std_result
,
std_arg0
,
std_arg1
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
input2
.
begin
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
}
else
{
});
}
else
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
input1
(
idx
.
begin
(),
idx
.
end
()),
input2
(
idx
.
begin
(),
idx
.
end
()));
});
}
}
);
}
);
}
return
result
;
}
...
...
src/include/migraphx/op/capture.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
capture
{
std
::
size_t
ins_index
;
std
::
function
<
void
(
std
::
size_t
ins_index
,
std
::
vector
<
argument
>
)
>
f
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
ins_index
,
"ins_index"
));
}
std
::
string
name
()
const
{
return
"capture"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
front
();
}
argument
compute
(
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
{
if
(
f
)
{
f
(
ins_index
,
args
);
}
else
{
MIGRAPHX_THROW
(
"CAPTURE: callback function is not callable!"
);
}
return
args
.
front
();
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/convolution.hpp
View file @
20b1d690
...
...
@@ -44,51 +44,24 @@ struct convolution
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
auto
t
=
input
.
type
();
if
(
padding_mode
==
default_
)
{
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
}};
}
else
if
(
padding_mode
==
same
)
{
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
2
])
/
stride
[
0
])),
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
3
])
/
stride
[
1
]))}};
}
else
if
(
padding_mode
==
valid
)
{
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
2
]
-
weights
.
lens
()[
2
]
+
1
)
/
stride
[
0
])),
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
3
]
-
weights
.
lens
()[
3
]
+
1
)
/
stride
[
1
]))}};
}
else
{
MIGRAPHX_THROW
(
"Invalid padding mode"
);
}
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
}};
}
};
...
...
src/include/migraphx/op/erf.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#define MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
erf
:
unary
<
erf
>
{
auto
apply
()
const
{
return
[](
auto
x
)
{
return
std
::
erf
(
x
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/logsoftmax.hpp
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -30,7 +23,7 @@ struct logsoftmax
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
standard
();
if
(
axis
<
0
||
axis
>
inputs
[
0
].
lens
().
size
())
if
(
axis
<
0
||
axis
>
=
inputs
[
0
].
lens
().
size
())
{
MIGRAPHX_THROW
(
"LogSoftMax: input axis value "
+
std
::
to_string
(
axis
)
+
" is out of range"
);
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
20b1d690
...
...
@@ -35,14 +35,28 @@ struct multibroadcast
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
MIGRAPHX_THROW
(
"inputs dimensions should be > 0"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPHX_THROW
(
"inputs dimensions should <= output size"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
input
.
lens
()[
i
]
and
input
.
lens
()[
i
]
!=
1
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
{
...
...
src/include/migraphx/op/pooling.hpp
View file @
20b1d690
...
...
@@ -31,7 +31,7 @@ struct pooling
{
return
pack
(
f
(
self
.
mode
,
"mode"
),
f
(
self
.
padding
,
"padding"
),
f
(
self
.
padding
,
"padding_mode"
),
f
(
self
.
padding
_mode
,
"padding_mode"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
));
}
...
...
@@ -48,51 +48,21 @@ struct pooling
assert
(
lengths
[
0
]
<=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
1
]
<=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
if
(
padding_mode
==
default_
)
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
],
stride
[
0
])
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
],
stride
[
1
])
+
1
)),
}};
}
else
if
(
padding_mode
==
same
)
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
ceil_divide
<
std
::
size_t
>
(
input
.
lens
()[
2
],
stride
[
0
]),
ceil_divide
<
std
::
size_t
>
(
input
.
lens
()[
3
],
stride
[
1
])}};
}
else
if
(
padding_mode
==
valid
)
{
return
{
t
,
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
2
]
-
lengths
[
0
],
stride
[
0
])
+
1
)),
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
],
stride
[
0
])
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
3
]
-
lengths
[
1
],
stride
[
1
])
+
1
)),
floor_divide
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
],
stride
[
1
])
+
1
)),
}};
}
else
{
MIGRAPHX_THROW
(
"Invalid padding mode"
);
}
}
};
...
...
src/include/migraphx/op/pow.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_POW_HPP
#define MIGRAPHX_GUARD_OPERATORS_POW_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
pow
:
binary
<
pow
>
{
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
pow
(
x
,
y
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/quant_convolution.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
quant_convolution
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{{
1
,
1
}};
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
group
,
"group"
));
}
std
::
string
name
()
const
{
return
"quant_convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
auto
t
=
input
.
type
();
// all input type must be int8_type and output is float_type
if
(
t
!=
shape
::
int8_type
)
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t"
);
}
t
=
shape
::
int32_type
;
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
}};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/quant_dot.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
quant_dot
{
int32_t
alpha
=
1
;
int32_t
beta
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
as_number
(
self
.
alpha
),
"alpha"
),
f
(
as_number
(
self
.
beta
),
"beta"
));
}
std
::
string
name
()
const
{
return
"quant_dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
same_type
();
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
t
!=
shape
::
int8_type
)
{
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
{
MIGRAPHX_THROW
(
"QUANT_DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
// k be multiple of 4
if
((
a
.
lens
()[
dim_1
]
%
4
)
!=
0
)
{
MIGRAPHX_THROW
(
"QUANT_DOT: size of A {"
+
to_string_range
(
a
.
lens
())
+
"} and B {"
+
to_string_range
(
b
.
lens
())
+
"} must be multiple of 4 for int8 type"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
{
MIGRAPHX_THROW
(
"QUANT_DOT: dimension mismatch, operand C: {"
+
to_string_range
(
inputs
.
at
(
2
).
lens
())
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
if
(
inputs
.
size
()
==
3
&&
inputs
.
at
(
2
).
type
()
!=
shape
::
int32_type
)
{
MIGRAPHX_THROW
(
"QUANT_DOT: operand C type must be int32"
);
}
return
{
shape
::
int32_type
,
out_lens
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reduce_mean.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#include <migraphx/op/reduce_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
reduce_mean
:
reduce_op
<
reduce_mean
>
{
reduce_mean
()
{}
reduce_mean
(
std
::
vector
<
int64_t
>
ax
)
:
reduce_op
(
std
::
move
(
ax
))
{}
auto
op
()
const
{
return
[
=
](
auto
x
,
auto
y
)
{
return
x
+
y
;
};
}
auto
output
(
const
shape
&
s
)
const
{
return
[
&
](
auto
val
)
{
return
val
/
s
.
elements
();
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reduce_op.hpp
0 → 100644
View file @
20b1d690
#ifndef MIGRAPHX_GUARD_OPERATORS_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
lowest
{
template
<
class
T
>
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
};
struct
highest
{
template
<
class
T
>
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
max
();
}
};
struct
zero
{
template
<
class
T
>
operator
T
()
const
{
return
T
{
0
};
}
};
template
<
class
Derived
>
struct
reduce_op
:
op_name
<
Derived
>
{
std
::
vector
<
std
::
int64_t
>
axes
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
{
auto
tuned_axes
=
axes
;
if
(
tuned_axes
.
empty
())
{
tuned_axes
.
resize
(
n_dim
);
std
::
iota
(
tuned_axes
.
begin
(),
tuned_axes
.
end
(),
0
);
}
else
{
for
(
auto
&
axis
:
tuned_axes
)
{
int64_t
s_dim
=
static_cast
<
int64_t
>
(
n_dim
);
if
(
axis
>=
s_dim
or
axis
<
-
s_dim
)
{
MIGRAPHX_THROW
(
"REDUCE_OP: axis out of range"
);
}
if
(
axis
<
0
)
{
axis
+=
n_dim
;
}
}
}
return
tuned_axes
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
auto
tuned_axes
=
tune_axes
(
lens
.
size
());
for
(
auto
axis
:
tuned_axes
)
{
lens
[
axis
]
=
1
;
}
return
{
s
.
type
(),
lens
};
}
template
<
class
T
>
void
tune_dims
(
const
std
::
vector
<
int64_t
>&
tuned_axes
,
const
std
::
vector
<
T
>&
in_lens
,
std
::
vector
<
T
>&
out_lens
)
const
{
for
(
auto
axis
:
tuned_axes
)
{
out_lens
[
axis
]
=
in_lens
[
axis
];
}
}
template
<
class
T
>
void
reduce
(
tensor_view
<
T
>&
input
,
shape
&
batch_shape
,
std
::
vector
<
int64_t
>&
tuned_axes
,
std
::
vector
<
std
::
size_t
>&
out_idx
,
tensor_view
<
T
>&
output
)
const
{
auto
data_idx
=
out_idx
;
T
val
=
static_cast
<
const
Derived
&>
(
*
this
).
init
();
shape_for_each
(
batch_shape
,
[
&
](
auto
b_idx
)
{
this
->
tune_dims
(
tuned_axes
,
b_idx
,
data_idx
);
val
=
static_cast
<
const
Derived
&>
(
*
this
).
op
()(
static_cast
<
const
Derived
&>
(
*
this
).
input
()(
input
(
data_idx
.
begin
(),
data_idx
.
end
())),
val
);
});
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
output
(
batch_shape
)(
val
);
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
arg_lens
=
args
.
front
().
get_shape
().
lens
();
auto
tuned_axes
=
tune_axes
(
arg_lens
.
size
());
std
::
vector
<
std
::
size_t
>
batch_lens
(
output_shape
.
lens
().
size
(),
1
);
tune_dims
(
tuned_axes
,
arg_lens
,
batch_lens
);
shape
batch_shape
{
output_shape
.
type
(),
batch_lens
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
out_idx
=
output_shape
.
multi
(
i
);
this
->
reduce
(
input
,
batch_shape
,
tuned_axes
,
out_idx
,
output
);
});
});
return
result
;
}
auto
init
()
const
{
return
zero
();
}
auto
input
()
const
{
return
[
&
](
auto
val
)
{
return
val
;
};
}
auto
output
(
const
shape
&
)
const
{
return
[
&
](
auto
val
)
{
return
val
;
};
}
reduce_op
()
{}
reduce_op
(
std
::
vector
<
int64_t
>
ax
)
:
axes
(
std
::
move
(
ax
))
{}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Prev
1
2
3
4
5
6
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment