Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
31065c7d
Commit
31065c7d
authored
Oct 31, 2022
by
charlie
Browse files
Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test
parents
6bec381f
6acbd4e4
Changes
482
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
420 additions
and
212 deletions
+420
-212
src/include/migraphx/op/mod.hpp
src/include/migraphx/op/mod.hpp
+1
-10
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+66
-25
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+104
-38
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+3
-3
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+5
-7
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+3
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+2
-2
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+68
-29
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+5
-4
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+76
-41
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+40
-18
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+0
-1
src/include/migraphx/pad_calc.hpp
src/include/migraphx/pad_calc.hpp
+15
-11
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+2
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+4
-3
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+4
-4
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+15
-5
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+1
-1
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+5
-5
No files found.
src/include/migraphx/op/mod.hpp
View file @
31065c7d
...
@@ -24,17 +24,8 @@
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -47,9 +38,9 @@ struct mod : binary<mod>
...
@@ -47,9 +38,9 @@ struct mod : binary<mod>
{
{
auto
a
=
base_attributes
();
auto
a
=
base_attributes
();
a
[
"commutative"
]
=
false
;
a
[
"commutative"
]
=
false
;
a
[
"point_op"
]
=
"${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})"
;
return
a
;
return
a
;
}
}
std
::
string
point_function
()
const
{
return
"mod"
;
}
auto
apply
()
const
auto
apply
()
const
{
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
((
std
::
remainder
(
x
,
y
))
+
y
,
y
);
};
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
((
std
::
remainder
(
x
,
y
))
+
y
,
y
);
};
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
31065c7d
...
@@ -26,64 +26,105 @@
...
@@ -26,64 +26,105 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/common.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator: one input and two inputs.
* One input version uses output_lens attribute and broadcasts to it.
* Two inputs version broadcasts both inputs to the common shape at evaluation time.
*/
struct
multibroadcast
struct
multibroadcast
{
{
std
::
vector
<
std
::
size_t
>
output_lens
;
std
::
vector
<
std
::
size_t
>
output_lens
=
{};
// optional attribute
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
output_lens
,
"out_lens"
));
return
pack
(
f
(
self
.
output_lens
,
"out_lens"
)
,
f
(
self
.
output_dyn_dims
,
"out_dyn_dims"
)
);
}
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
auto
t
=
inputs
.
at
(
0
).
type
();
{
auto
s0
=
inputs
.
at
(
0
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
if
(
s0
.
max_lens
().
empty
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input
s
dimensions should
<= output size
"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should
be > 0
"
);
}
}
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
make_bcast_strides
=
[
&
](
std
::
vector
<
std
::
size_t
>
bcast_lens
,
std
::
size_t
offset
)
{
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
std
::
vector
<
size_t
>
bcast_strides
(
bcast_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
bcast_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
{
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
}
}
return
bcast_strides
;
};
if
(
inputs
.
size
()
==
1
)
{
{
if
(
output_lens
[
i
+
offset
]
!=
input
.
lens
()[
i
]
and
in
put
.
lens
()[
i
]
!=
1
)
if
(
s0
.
lens
().
size
()
>
out
put
_
lens
.
size
()
)
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input
.
lens
())
+
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should <= output size"
);
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
s0
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
s0
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
auto
bcast_strides
=
make_bcast_strides
(
output_lens
,
offset
);
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
}
else
{
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
// two inputs
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
{
bcast_strides
[
i
+
offset
]
=
input
.
strides
()[
i
];
if
(
not
output_dyn_dims
.
empty
())
{
return
{
t
,
output_dyn_dims
};
}
return
{
t
,
compute_broadcasted_dyn_dims
(
s0
,
s1
)};
}
else
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
auto
offset
=
bcast_lens
.
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
bcast_lens
,
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
}
}
}
return
{
t
,
output_lens
,
bcast_strides
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
31065c7d
...
@@ -45,11 +45,13 @@ namespace op {
...
@@ -45,11 +45,13 @@ namespace op {
struct
nonmaxsuppression
struct
nonmaxsuppression
{
{
bool
center_point_box
=
false
;
bool
center_point_box
=
false
;
bool
use_dyn_output
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
center_point_box
,
"center_point_box"
));
return
pack
(
f
(
self
.
center_point_box
,
"center_point_box"
),
f
(
self
.
use_dyn_output
,
"use_dyn_output"
));
}
}
std
::
string
name
()
const
{
return
"nonmaxsuppression"
;
}
std
::
string
name
()
const
{
return
"nonmaxsuppression"
;
}
...
@@ -57,27 +59,81 @@ struct nonmaxsuppression
...
@@ -57,27 +59,81 @@ struct nonmaxsuppression
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
// requires at least 2 inputs
// requires at least 2 inputs
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
,
true
}.
only_dims
(
3
).
same_ndims
();
auto
lens
=
inputs
.
front
().
lens
();
auto
boxes_max_lens
=
inputs
.
at
(
0
).
max_lens
();
// num batches * num boxes
const
auto
max_num_boxes
=
boxes_max_lens
.
at
(
0
)
*
boxes_max_lens
.
at
(
1
);
// check input shape
auto
fixed_shape_error_check
=
[
&
]()
{
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
auto
lens
=
inputs
.
front
().
lens
();
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
}
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
}
};
if
(
use_dyn_output
)
{
{
MIGRAPHX_THROW
(
if
(
inputs
.
at
(
0
).
dynamic
())
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
{
// both boxes and scores should be dynamic
// check dynamic dimensions are consistent
const
auto
boxes_dims
=
inputs
.
at
(
0
).
dyn_dims
();
const
auto
scores_dims
=
inputs
.
at
(
1
).
dyn_dims
();
if
(
boxes_dims
.
at
(
1
)
!=
scores_dims
.
at
(
2
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input"
);
}
if
(
boxes_dims
.
at
(
0
)
!=
scores_dims
.
at
(
0
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input"
);
}
}
else
if
(
inputs
.
at
(
1
).
dynamic
())
{
// scores has dynamic shape, boxes fixed shape
// check that it is only a dynamic number of classes
const
auto
scores_dims
=
inputs
.
at
(
1
).
dyn_dims
();
const
auto
boxes_lens
=
inputs
.
at
(
0
).
lens
();
if
(
not
scores_dims
.
at
(
0
).
is_fixed
()
or
scores_dims
.
at
(
0
).
max
!=
boxes_lens
.
at
(
0
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: scores dynamic num_classes; num_batches not "
"fixed or mismatched"
);
}
if
(
not
scores_dims
.
at
(
2
).
is_fixed
()
or
scores_dims
.
at
(
2
).
max
!=
boxes_lens
.
at
(
1
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: scores dynamic num_classes; "
"spatial_dimension not fixed or mismatches"
);
}
}
else
{
fixed_shape_error_check
();
}
std
::
vector
<
shape
::
dynamic_dimension
>
out_lens
=
{};
out_lens
.
push_back
({
0
,
max_num_boxes
,
0
});
out_lens
.
push_back
({
3
,
3
,
0
});
return
{
shape
::
int64_type
,
out_lens
};
}
}
else
// check batch sizes
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
{
{
MIGRAPHX_THROW
(
if
(
inputs
.
at
(
0
).
dynamic
()
or
inputs
.
at
(
1
).
dynamic
())
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic input shape with use_dyn_output set to false"
);
}
fixed_shape_error_check
();
std
::
vector
<
std
::
size_t
>
out_lens
=
{
max_num_boxes
,
3
};
return
{
shape
::
int64_type
,
out_lens
};
}
}
std
::
vector
<
int64_t
>
out_lens
(
2
);
out_lens
.
at
(
0
)
=
lens
.
at
(
1
);
out_lens
.
at
(
1
)
=
3
;
return
{
shape
::
int64_type
,
out_lens
};
}
}
struct
box
struct
box
...
@@ -181,13 +237,13 @@ struct nonmaxsuppression
...
@@ -181,13 +237,13 @@ struct nonmaxsuppression
}
}
template
<
class
Output
,
class
Boxes
,
class
Scores
>
template
<
class
Output
,
class
Boxes
,
class
Scores
>
void
compute_nms
(
Output
output
,
std
::
size_t
compute_nms
(
Output
output
,
Boxes
boxes
,
Boxes
boxes
,
Scores
scores
,
Scores
scores
,
const
shape
&
output_shape
,
const
shape
&
max_
output_shape
,
std
::
size_t
max_output_boxes_per_class
,
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
double
iou_threshold
,
double
score_threshold
)
const
double
score_threshold
)
const
{
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
...
@@ -197,7 +253,7 @@ struct nonmaxsuppression
...
@@ -197,7 +253,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index]
// boxes of a class with NMS applied [score, index]
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
int64_t
>
selected_indices
;
std
::
vector
<
int64_t
>
selected_indices
;
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
selected_boxes_inside_class
.
reserve
(
max_
output_shape
.
elements
());
// iterate over batches and classes
// iterate over batches and classes
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
...
@@ -210,7 +266,7 @@ struct nonmaxsuppression
...
@@ -210,7 +266,7 @@ struct nonmaxsuppression
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
{
// Check with existing selected boxes for this class, remove box if it
// Check with existing selected boxes for this class, remove box if it
...
@@ -237,11 +293,14 @@ struct nonmaxsuppression
...
@@ -237,11 +293,14 @@ struct nonmaxsuppression
}
}
});
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
return
selected_indices
.
size
()
/
3
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
// make buffer of maximum size
shape
max_output_shape
=
{
output_shape
.
type
(),
output_shape
.
max_lens
()};
argument
result
{
max_output_shape
};
std
::
size_t
max_output_boxes_per_class
=
std
::
size_t
max_output_boxes_per_class
=
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
...
@@ -249,22 +308,29 @@ struct nonmaxsuppression
...
@@ -249,22 +308,29 @@ struct nonmaxsuppression
{
{
return
result
;
return
result
;
}
}
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
std
::
size_t
num_selected
=
0
;
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
compute_nms
(
output
,
num_selected
=
compute_nms
(
output
,
boxes
,
boxes
,
scores
,
scores
,
output_shape
,
max_
output_shape
,
max_output_boxes_per_class
,
max_output_boxes_per_class
,
iou_threshold
,
iou_threshold
,
score_threshold
);
score_threshold
);
});
});
});
});
if
(
use_dyn_output
)
return
result
;
{
return
result
.
reshape
({
output_shape
.
type
(),
{
num_selected
,
3
}});
}
else
{
return
result
;
}
}
}
};
};
...
...
src/include/migraphx/op/pooling.hpp
View file @
31065c7d
...
@@ -64,8 +64,8 @@ struct pooling
...
@@ -64,8 +64,8 @@ struct pooling
void
check_attribute_size
()
const
void
check_attribute_size
()
const
{
{
if
(
not
(
(
padding
.
size
()
=
=
stride
.
size
()
or
(
padding
.
size
()
/
2
)
=
=
stride
.
size
())
and
if
((
padding
.
size
()
!
=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!
=
stride
.
size
())
or
stride
.
size
()
=
=
lengths
.
size
())
)
stride
.
size
()
!
=
lengths
.
size
())
{
{
MIGRAPHX_THROW
(
"POOLING: inconsistent attribute sizes"
);
MIGRAPHX_THROW
(
"POOLING: inconsistent attribute sizes"
);
}
}
...
@@ -83,7 +83,7 @@ struct pooling
...
@@ -83,7 +83,7 @@ struct pooling
size_t
kdims
=
input_lens
.
size
()
-
2
;
size_t
kdims
=
input_lens
.
size
()
-
2
;
auto
input_size
=
inputs
[
0
].
lens
().
size
();
auto
input_size
=
inputs
[
0
].
lens
().
size
();
auto
padding_size
=
padding
.
size
();
auto
padding_size
=
padding
.
size
();
if
(
not
(
input_size
=
=
padding_size
/
2
+
2
or
input_size
=
=
padding_size
+
2
)
)
if
(
input_size
!
=
padding_size
/
2
+
2
and
input_size
!
=
padding_size
+
2
)
{
{
MIGRAPHX_THROW
(
"POOLING: input and attribute size mismatch!"
);
MIGRAPHX_THROW
(
"POOLING: input and attribute size mismatch!"
);
}
}
...
...
src/include/migraphx/op/quant_convolution.hpp
View file @
31065c7d
...
@@ -41,9 +41,8 @@ struct quant_convolution
...
@@ -41,9 +41,8 @@ struct quant_convolution
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
dilation
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
dilation
=
{
1
,
1
};
padding_mode_t
padding_mode
=
default_
;
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
int
group
=
1
;
bool
use_dynamic_same_auto_pad
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -52,8 +51,7 @@ struct quant_convolution
...
@@ -52,8 +51,7 @@ struct quant_convolution
f
(
self
.
stride
,
"stride"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
group
,
"group"
),
f
(
self
.
group
,
"group"
));
f
(
self
.
use_dynamic_same_auto_pad
,
"use_dynamic_same_auto_pad"
));
}
}
value
attributes
()
const
value
attributes
()
const
...
@@ -65,8 +63,8 @@ struct quant_convolution
...
@@ -65,8 +63,8 @@ struct quant_convolution
void
check_attribute_size
()
const
void
check_attribute_size
()
const
{
{
if
(
not
(
(
padding
.
size
()
=
=
stride
.
size
()
or
(
padding
.
size
()
/
2
)
=
=
stride
.
size
())
and
if
((
padding
.
size
()
!
=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!
=
stride
.
size
())
or
stride
.
size
()
=
=
dilation
.
size
())
)
stride
.
size
()
!
=
dilation
.
size
())
{
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: inconsistent attribute sizes"
);
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: inconsistent attribute sizes"
);
}
}
...
...
src/include/migraphx/op/quant_dot.hpp
View file @
31065c7d
...
@@ -49,13 +49,14 @@ struct quant_dot
...
@@ -49,13 +49,14 @@ struct quant_dot
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
}
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
}
}
// only handle the case that the batch size of a and b are the same
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
...
...
src/include/migraphx/op/slice.hpp
View file @
31065c7d
...
@@ -78,7 +78,7 @@ struct slice
...
@@ -78,7 +78,7 @@ struct slice
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
auto
offset
=
0
;
if
(
!
axes
.
empty
())
if
(
not
axes
.
empty
())
{
{
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
{
...
@@ -109,7 +109,7 @@ struct slice
...
@@ -109,7 +109,7 @@ struct slice
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
}
}
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
{
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
}
}
...
...
src/include/migraphx/op/squeeze.hpp
View file @
31065c7d
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -54,52 +55,90 @@ struct squeeze
...
@@ -54,52 +55,90 @@ struct squeeze
std
::
string
name
()
const
{
return
"squeeze"
;
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
if
(
input_shape
.
dynamic
())
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
{
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
std
::
vector
<
shape
::
dynamic_dimension
>
one_dyn_dims
{{
1
,
1
,
0
},
{
1
,
1
,
1
}};
}
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
std
::
vector
<
std
::
size_t
>
new_lens
;
return
not
contains
(
one_dyn_dims
,
input_shape
.
dyn_dims
()[
axis
]);
std
::
vector
<
std
::
size_t
>
new_strides
;
}))
if
(
axes
.
empty
())
{
{
MIGRAPHX_THROW
(
for
(
auto
i
:
range
(
old_lens
.
size
()))
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}"
);
}
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
if
(
axes
.
empty
())
{
{
if
(
old_lens
[
i
]
!=
1
)
for
(
auto
i
:
range
(
input_shape
.
ndim
())
)
{
{
new_lens
.
push_back
(
old_lens
[
i
]);
auto
dd
=
input_shape
.
dyn_dims
()[
i
];
new_strides
.
push_back
(
old_strides
[
i
]);
if
(
not
contains
(
one_dyn_dims
,
dd
))
{
dyn_dims
.
push_back
(
dd
);
}
}
}
}
}
}
else
else
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
{
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
(
))
for
(
auto
i
:
range
(
input_shape
.
ndim
()
))
{
{
new_lens
.
push_back
(
old_lens
[
i
]);
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
new_strides
.
push_back
(
old_strides
[
i
]);
{
dyn_dims
.
push_back
(
input_shape
.
dyn_dims
()[
i
]);
}
}
}
}
}
}
return
{
input_shape
.
type
(),
dyn_dims
};
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
}
else
else
{
{
return
shape
{
type
,
new_lens
,
new_strides
};
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
{
MIGRAPHX_THROW
(
"SQUEEZE: static axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
vector
<
std
::
size_t
>
new_strides
;
if
(
axes
.
empty
())
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
{
if
(
old_lens
[
i
]
!=
1
)
{
new_lens
.
push_back
(
old_lens
[
i
]);
new_strides
.
push_back
(
old_strides
[
i
]);
}
}
}
else
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
new_lens
.
push_back
(
old_lens
[
i
]);
new_strides
.
push_back
(
old_strides
[
i
]);
}
}
}
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
else
{
return
shape
{
type
,
new_lens
,
new_strides
};
}
}
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/transpose.hpp
View file @
31065c7d
...
@@ -59,7 +59,7 @@ struct transpose
...
@@ -59,7 +59,7 @@ struct transpose
}
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
not
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
{
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
}
}
...
...
src/include/migraphx/op/unary.hpp
View file @
31065c7d
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -62,9 +63,9 @@ struct unary : op_name<Derived>
...
@@ -62,9 +63,9 @@ struct unary : op_name<Derived>
value
attributes
()
const
{
return
base_attributes
();
}
value
attributes
()
const
{
return
base_attributes
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
1
);
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)
,
true
}.
has
(
1
);
auto
s
=
inputs
.
at
(
0
);
auto
s
=
inputs
.
at
(
0
);
if
(
s
.
scalar
())
if
(
s
.
dynamic
()
or
s
.
scalar
())
{
{
return
s
;
return
s
;
}
}
...
@@ -78,9 +79,9 @@ struct unary : op_name<Derived>
...
@@ -78,9 +79,9 @@ struct unary : op_name<Derived>
}
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
std
::
transform
(
input
.
begin
(),
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
31065c7d
...
@@ -29,11 +29,20 @@
...
@@ -29,11 +29,20 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct
unsqueeze
struct
unsqueeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
...
@@ -56,63 +65,89 @@ struct unsqueeze
...
@@ -56,63 +65,89 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
if
(
input_shape
.
dynamic
())
auto
old_strides
=
input_shape
.
strides
();
if
(
input_shape
.
scalar
())
{
{
if
(
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
if
(
not
steps
.
empty
())
return
shape
{
type
,
old_lens
};
{
else
MIGRAPHX_THROW
(
"UNSQUEEZE_dyn: nonempty steps attribute"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
auto
new_ndim
=
input_shape
.
ndim
()
+
axes
.
size
();
std
::
size_t
k
=
0
;
for
(
auto
i
:
range
(
new_ndim
))
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
dyn_dims
.
push_back
({
1
,
1
,
0
});
}
else
{
dyn_dims
.
push_back
(
input_shape
.
dyn_dims
().
at
(
k
++
));
}
}
return
{
input_shape
.
type
(),
dyn_dims
};
}
}
else
{
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
input_shape
.
scalar
())
{
if
(
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
return
shape
{
type
,
old_lens
};
else
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
if
(
steps
.
size
()
>
axes
.
size
())
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_strides
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_strides
(
new_size
);
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
for
(
auto
i
:
range
(
new_size
))
{
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
{
{
std
::
int64_t
step
=
1
;
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
steps
.
size
())
if
(
axis_idx
<
axes
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
{
if
((
old_lens
[
p
]
%
step
)
!=
0
)
std
::
int64_t
step
=
1
;
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
if
(
axis_idx
<
steps
.
size
())
old_lens
[
p
]
/=
step
;
step
=
steps
[
axis_idx
];
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
}
else
{
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
}
}
}
else
else
{
{
if
(
step
!=
1
)
new_lens
[
i
]
=
old_lens
[
p
];
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
old_strides
[
p
++
];
new_strides
[
i
]
=
1
;
}
}
}
}
else
return
shape
{
type
,
new_lens
,
new_strides
};
{
new_lens
[
i
]
=
old_lens
[
p
];
new_strides
[
i
]
=
old_strides
[
p
++
];
}
}
}
return
shape
{
type
,
new_lens
,
new_strides
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/operation.hpp
View file @
31065c7d
...
@@ -32,6 +32,8 @@
...
@@ -32,6 +32,8 @@
#include <utility>
#include <utility>
#include <unordered_map>
#include <unordered_map>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
...
@@ -199,9 +201,12 @@ auto compute_op(rank<1>,
...
@@ -199,9 +201,12 @@ auto compute_op(rank<1>,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
))
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)),
input
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)),
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -220,9 +225,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
...
@@ -220,9 +225,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
))
,
input
))
{
{
return
x
.
compute
(
output_shape
,
input
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
))
,
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -244,9 +249,11 @@ auto compute_op(rank<1>,
...
@@ -244,9 +249,11 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
))
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -278,9 +285,17 @@ auto compute_op(rank<4>,
...
@@ -278,9 +285,17 @@ auto compute_op(rank<4>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -290,9 +305,11 @@ auto compute_op(rank<3>,
...
@@ -290,9 +305,11 @@ auto compute_op(rank<3>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
))
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -302,9 +319,10 @@ auto compute_op(rank<2>,
...
@@ -302,9 +319,10 @@ auto compute_op(rank<2>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
output
,
inputs
))
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
{
return
x
.
compute
(
output
,
inputs
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
))
,
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -314,9 +332,12 @@ auto compute_op(rank<1>,
...
@@ -314,9 +332,12 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
))
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -348,7 +369,8 @@ auto is_context_free_op(rank<1>,
...
@@ -348,7 +369,8 @@ auto is_context_free_op(rank<1>,
const
T
&
x
,
const
T
&
x
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)),
input
),
std
::
true_type
{});
template
<
class
T
>
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
...
@@ -1066,7 +1088,7 @@ struct operation
...
@@ -1066,7 +1088,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -1237,7 +1259,7 @@ struct operation
...
@@ -1237,7 +1259,7 @@ struct operation
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
@@ -1276,7 +1298,7 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -1276,7 +1298,7 @@ inline const ValueType& any_cast(const operation& x)
}
}
#endif
#endif
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
inline
value
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
src/include/migraphx/operators.hpp
View file @
31065c7d
...
@@ -35,7 +35,6 @@
...
@@ -35,7 +35,6 @@
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/capture.hpp>
...
...
src/include/migraphx/pad_calc.hpp
View file @
31065c7d
...
@@ -24,9 +24,10 @@
...
@@ -24,9 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <cstdint>
#include <vector>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -42,18 +43,21 @@ void calculate_padding(int64_t idx,
...
@@ -42,18 +43,21 @@ void calculate_padding(int64_t idx,
/*!
/*!
* Calculate the padding for auto_padding. Used for dynamic shapes
* Calculate the padding for auto_padding. Used for dynamic shapes
* where the padding calculation must be done at evaluation time.
* where the padding calculation must be done at evaluation time.
* \param tensor_lens input tensor image shape
* \param k_lens weights kernel shape
* \param strides strides for the kernel
* \param dilations dilations for the kernel
* \param use_upper put odd padding on upper or lower side
* \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
* \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
*/
*/
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
std
::
vector
<
std
::
size_t
>
tensor_lens
,
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
std
::
vector
<
std
::
size_t
>
k_lens
,
const
std
::
vector
<
std
::
size_t
>&
wei_lens
,
std
::
vector
<
std
::
size_t
>
strides
,
const
std
::
vector
<
std
::
size_t
>&
strides
,
std
::
vector
<
std
::
size_t
>
dilations
,
const
std
::
vector
<
std
::
size_t
>&
dilations
,
bool
use_upper
=
true
);
bool
use_upper
);
// Used for dynamic auto padding of convolution operators since padding needs to be computed at
// evaluation time.
shape
compute_padded_shape
(
const
shape
&
input
,
const
shape
&
weights
,
const
std
::
vector
<
std
::
size_t
>&
padding
,
const
std
::
vector
<
std
::
size_t
>&
stride
,
const
std
::
vector
<
std
::
size_t
>&
dilation
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/pass.hpp
View file @
31065c7d
...
@@ -238,7 +238,7 @@ struct pass
...
@@ -238,7 +238,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -292,7 +292,7 @@ struct pass
...
@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/program.hpp
View file @
31065c7d
...
@@ -37,6 +37,7 @@
...
@@ -37,6 +37,7 @@
#include <migraphx/assignment_options.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/execution_environment.hpp>
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <iostream>
...
@@ -76,8 +77,8 @@ struct program
...
@@ -76,8 +77,8 @@ struct program
std
::
unordered_map
<
std
::
string
,
shape
>
get_parameter_shapes
()
const
;
std
::
unordered_map
<
std
::
string
,
shape
>
get_parameter_shapes
()
const
;
std
::
vector
<
argument
>
eval
(
parameter_map
params
)
const
;
std
::
vector
<
argument
>
eval
(
parameter_map
params
,
execution_environment
exec_env
=
execution_environment
{})
const
;
std
::
size_t
size
()
const
;
std
::
size_t
size
()
const
;
std
::
vector
<
shape
>
get_output_shapes
()
const
;
std
::
vector
<
shape
>
get_output_shapes
()
const
;
...
@@ -124,7 +125,7 @@ struct program
...
@@ -124,7 +125,7 @@ struct program
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
not
(
x
==
y
);
}
// module related api
// module related api
module
*
create_module
(
const
std
::
string
&
name
);
module
*
create_module
(
const
std
::
string
&
name
);
...
...
src/include/migraphx/raw_data.hpp
View file @
31065c7d
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
template
<
class
T
>
template
<
class
T
>
bool
matches
()
const
bool
matches
()
const
{
{
return
is_data_ptr
<
T
>
{}
||
return
is_data_ptr
<
T
>
{}
or
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
}
}
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
{
{
auto
&&
s
=
x
.
get_shape
();
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
}
}
...
@@ -241,7 +241,7 @@ template <class T>
...
@@ -241,7 +241,7 @@ template <class T>
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
{
{
auto
&&
s
=
x
.
front
().
get_shape
();
auto
&&
s
=
x
.
front
().
get_shape
();
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
return
[
&
](
auto
v
)
{
...
@@ -281,7 +281,7 @@ template <class T,
...
@@ -281,7 +281,7 @@ template <class T,
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/reflect.hpp
View file @
31065c7d
...
@@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
...
@@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
}
}
template
<
class
T
>
template
<
class
T
>
auto
reflectable_impl
(
rank
<
1
>
,
T
&
&
x
)
auto
reflectable_impl
(
rank
<
1
>
,
const
T
&
x
)
->
decltype
(
T
::
reflect
(
x
,
reflect_placeholder
{}),
std
::
true_type
{});
->
decltype
(
T
::
reflect
(
x
,
reflect_placeholder
{}),
std
::
true_type
{});
template
<
class
T
>
template
<
class
T
>
auto
reflectable_impl
(
rank
<
0
>
,
T
&
&
)
->
decltype
(
std
::
false_type
{});
auto
reflectable_impl
(
rank
<
0
>
,
const
T
&
)
->
decltype
(
std
::
false_type
{});
template
<
class
T
>
template
<
class
T
>
struct
remove_rvalue_reference
struct
remove_rvalue_reference
...
@@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
...
@@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
template
<
class
T
>
template
<
class
T
>
auto
reflect_tie
(
T
&
x
)
auto
reflect_tie
(
T
&
x
)
{
{
return
reflect
(
x
,
[](
auto
&&
y
,
auto
&&
...)
{
return
detail
::
wrap
<
decltype
(
y
)
>
(
y
);
})(
return
reflect
(
x
,
[](
auto
&&
y
,
auto
&&
...)
{
[](
auto
&&
...
xs
)
{
return
detail
::
auto_tuple
(
xs
.
get
()...);
});
// cppcheck-suppress UnnecessaryElseStatement
if
constexpr
(
is_reflectable
<
decltype
(
y
)
>
{})
{
auto
t
=
reflect_tie
(
y
);
return
detail
::
wrap
<
decltype
(
t
)
>
(
t
);
}
else
{
return
detail
::
wrap
<
decltype
(
y
)
>
(
y
);
}
})([](
auto
&&
...
xs
)
{
return
detail
::
auto_tuple
(
xs
.
get
()...);
});
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -129,7 +139,7 @@ template <class T>
...
@@ -129,7 +139,7 @@ template <class T>
struct
reflect_equality
struct
reflect_equality
{
{
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
not
(
x
==
y
);
}
};
};
template
<
class
T
>
template
<
class
T
>
...
...
src/include/migraphx/requires.hpp
View file @
31065c7d
...
@@ -31,7 +31,7 @@ namespace migraphx {
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
bool
...
Bs
>
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
or
true
)...
>>
// NOLINT
{
{
};
};
...
...
src/include/migraphx/rewrite_
batchnorm
.hpp
→
src/include/migraphx/rewrite_
gelu
.hpp
View file @
31065c7d
...
@@ -21,8 +21,8 @@
...
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_
FWD_CONV_BATCHNORM_
REWRITE_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_
GELU_
HPP
#define MIGRAPHX_GUARD_RTGLIB_
FWD_CONV_BATCHNORM_
REWRITE_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_
GELU_
HPP
#include <string>
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
...
@@ -34,11 +34,11 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -34,11 +34,11 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
struct
module
;
/**
/**
* Rewrite
batchnorm to a multiply and add.
* Rewrite
gelu standard formula as the sigmoid approximation formula
*/
*/
struct
rewrite_
batchnorm
struct
rewrite_
gelu
{
{
std
::
string
name
()
const
{
return
"rewrite_
batchnorm
"
;
}
std
::
string
name
()
const
{
return
"rewrite_
gelu
"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
...
...
Prev
1
2
3
4
5
6
7
8
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment