Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
870a396b
Commit
870a396b
authored
Jan 23, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
228b665c
d309e02f
Changes
473
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
379 additions
and
228 deletions
+379
-228
src/include/migraphx/match/layernorm.hpp
src/include/migraphx/match/layernorm.hpp
+2
-2
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+6
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+19
-11
src/include/migraphx/op/batch_norm_inference.hpp
src/include/migraphx/op/batch_norm_inference.hpp
+0
-70
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+14
-4
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+90
-33
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+2
-2
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+18
-9
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+1
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+8
-16
src/include/migraphx/op/deconvolution.hpp
src/include/migraphx/op/deconvolution.hpp
+2
-2
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+51
-22
src/include/migraphx/op/elu.hpp
src/include/migraphx/op/elu.hpp
+9
-5
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+39
-9
src/include/migraphx/op/fmod.hpp
src/include/migraphx/op/fmod.hpp
+0
-1
src/include/migraphx/op/layout.hpp
src/include/migraphx/op/layout.hpp
+23
-1
src/include/migraphx/op/leaky_relu.hpp
src/include/migraphx/op/leaky_relu.hpp
+7
-4
src/include/migraphx/op/mod.hpp
src/include/migraphx/op/mod.hpp
+1
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+66
-25
src/include/migraphx/op/pad.hpp
src/include/migraphx/op/pad.hpp
+21
-10
No files found.
src/include/migraphx/match/layernorm.hpp
View file @
870a396b
...
@@ -50,8 +50,8 @@ struct layernorm_matcher
...
@@ -50,8 +50,8 @@ struct layernorm_matcher
{
{
return
f
(
"div"
)(
arg
(
0
)(
x_minus_mean
()),
return
f
(
"div"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
skip_broadcasts
(
f
(
"sqrt"
)(
arg
(
1
)(
skip_broadcasts
(
f
(
"sqrt"
)(
arg
(
0
)(
arg
(
0
)(
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
has_value
(
1e-12
f
))))))));
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
is_constant
().
bind
(
"eps"
))))))));
}
}
auto
matcher
()
const
{
return
layernorm_onnx
();
}
auto
matcher
()
const
{
return
layernorm_onnx
();
}
...
...
src/include/migraphx/module.hpp
View file @
870a396b
...
@@ -205,6 +205,12 @@ struct module
...
@@ -205,6 +205,12 @@ struct module
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_cpp
(
std
::
ostream
&
os
,
print_cpp
(
std
::
ostream
&
os
,
...
...
src/include/migraphx/op/argmax.hpp
View file @
870a396b
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -56,12 +57,20 @@ struct argmax
...
@@ -56,12 +57,20 @@ struct argmax
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
lens
=
inputs
[
0
].
lens
();
const
auto
&
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
())
lens
[
axis
]
=
1
;
{
auto
dyn_dims
=
s0
.
dyn_dims
();
return
{
shape
::
int64_type
,
lens
};
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
else
{
auto
lens
=
s0
.
lens
();
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -79,19 +88,18 @@ struct argmax
...
@@ -79,19 +88,18 @@ struct argmax
max_index
=
i
;
max_index
=
i
;
}
}
}
}
return
max_index
;
return
max_index
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
out
put_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
out
put_shape
.
multi
(
i
);
auto
data_idx
=
dyn_out
.
com
put
ed
_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
});
});
});
});
...
...
src/include/migraphx/op/batch_norm_inference.hpp
deleted
100644 → 0
View file @
228b665c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <cmath>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
batch_norm_inference
{
float
epsilon
=
1.0e-6
f
;
float
momentum
=
0.9
f
;
std
::
string
name
()
const
{
return
"batch_norm_inference"
;
}
enum
bn_infer_mode_t
{
per_activation
,
spatial
,
};
bn_infer_mode_t
bn_mode
=
spatial
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
epsilon
,
"epsilon"
),
f
(
self
.
momentum
,
"momentum"
),
f
(
self
.
bn_mode
,
"bn_mode"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
.
data
(),
inputs
.
data
()
+
1
,
*
this
}.
same_ndims
();
check_shapes
{
inputs
.
data
()
+
1
,
inputs
.
data
()
+
inputs
.
size
(),
*
this
}.
same_shape
();
return
inputs
.
front
();
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/binary.hpp
View file @
870a396b
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
...
@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
value
attributes
()
const
{
return
base_attributes
();
}
value
attributes
()
const
{
return
base_attributes
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
2
).
same_type
().
same_dims
();
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
),
true
}
.
has
(
2
)
.
same_type
()
.
same_dims
();
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
==
s1
and
s0
.
packed
())
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
if
(
s0
==
s1
)
return
s0
;
MIGRAPHX_THROW
(
"BINARY: "
+
point_function
()
+
": fixed-dyn shape for inputs"
);
}
else
if
(
s0
==
s1
and
s0
.
packed
())
{
{
return
s0
;
return
s0
;
}
}
...
@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
...
@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
}
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
std
::
transform
(
input1
.
begin
(),
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
input1
.
end
(),
...
...
src/include/migraphx/op/broadcast.hpp
View file @
870a396b
...
@@ -27,23 +27,30 @@
...
@@ -27,23 +27,30 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/**
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
* 1 input version:
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
/// axis to zero.
* with axis = 0
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D static shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct
broadcast
struct
broadcast
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
=
{}
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -54,36 +61,86 @@ struct broadcast
...
@@ -54,36 +61,86 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
auto
input
=
inputs
.
at
(
0
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
input
.
type
();
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
if
(
inputs
.
size
()
==
1
)
// the broacast op is deprecated now, so not handling the negative
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
// the ONNX broadcast op is deprecated now, so not handling the negative
}
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis "
+
migraphx
::
to_string
(
axis
)
+
" is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
input
.
lens
().
size
())
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
{
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than input ndims"
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
{
// don't think this can occur?
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
}
return
output
;
}
}
else
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
// two inputs
}
auto
s1
=
inputs
.
at
(
1
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
if
(
s0
.
dynamic
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape"
);
}
if
(
s0
.
ndim
()
!=
1
)
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 has ndim "
+
migraphx
::
to_string
(
s0
.
ndim
())
+
", only handle ndim = 1"
);
}
if
(
axis
>=
s1
.
ndim
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: axis "
+
migraphx
::
to_string
(
axis
)
+
" is out of range"
);
}
if
(
s1
.
dynamic
())
{
s0
=
s0
.
to_dynamic
();
if
(
s0
.
dyn_dims
()[
0
]
!=
s1
.
dyn_dims
()[
axis
])
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length ("
+
migraphx
::
to_string
(
s0
.
dyn_dims
()[
0
])
+
" != "
+
migraphx
::
to_string
(
s1
.
dyn_dims
()[
axis
])
+
")"
);
}
return
s1
;
}
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
s0
.
lens
()[
0
]
!=
s1
.
lens
()[
axis
])
if
(
output
.
elements
()
<
input
.
elements
())
{
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to input size"
);
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with static s1 axis "
return
output
;
"dimension length ("
+
migraphx
::
to_string
(
s0
.
lens
()[
0
])
+
" != "
+
migraphx
::
to_string
(
s1
.
lens
()[
axis
])
+
")"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
s1
.
ndim
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
s1
.
lens
(),
std
::
move
(
bcast_strides
)};
return
output
;
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/common.hpp
View file @
870a396b
...
@@ -33,11 +33,11 @@ namespace migraphx {
...
@@ -33,11 +33,11 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
// Padding mode is default_ for fixed shape padding.
// same_lower and same_upper used for dynamic padding.
enum
padding_mode_t
enum
padding_mode_t
{
{
default_
,
// NOLINT
default_
,
// NOLINT
same
,
valid
,
same_lower
,
same_lower
,
same_upper
same_upper
};
};
...
...
src/include/migraphx/op/contiguous.hpp
View file @
870a396b
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -42,19 +43,27 @@ namespace op {
...
@@ -42,19 +43,27 @@ namespace op {
struct
contiguous
struct
contiguous
{
{
std
::
string
name
()
const
{
return
"contiguous"
;
}
std
::
string
name
()
const
{
return
"contiguous"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
if
(
inputs
.
front
().
standard
())
auto
s0
=
inputs
.
front
();
return
inputs
.
front
();
if
(
s0
.
dynamic
()
or
s0
.
standard
())
auto
lens
=
inputs
.
at
(
0
).
lens
();
{
auto
t
=
inputs
.
at
(
0
).
type
();
return
s0
;
return
{
t
,
lens
};
}
else
{
const
auto
&
lens
=
s0
.
lens
();
auto
t
=
s0
.
type
();
return
{
t
,
lens
};
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
assert
(
out
put_shape
.
standard
());
assert
(
dyn_out
.
com
put
ed
_shape
.
standard
());
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
...
...
src/include/migraphx/op/convert.hpp
View file @
870a396b
...
@@ -44,7 +44,7 @@ struct convert : unary<convert>
...
@@ -44,7 +44,7 @@ struct convert : unary<convert>
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
dynamic
())
if
(
input
.
dynamic
())
{
{
...
...
src/include/migraphx/op/convolution.hpp
View file @
870a396b
...
@@ -41,9 +41,8 @@ struct convolution
...
@@ -41,9 +41,8 @@ struct convolution
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
dilation
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
dilation
=
{
1
,
1
};
int
group
=
1
;
int
group
=
1
;
padding_mode_t
padding_mode
=
default_
;
padding_mode_t
padding_mode
=
default_
;
bool
use_dynamic_same_auto_pad
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -52,16 +51,15 @@ struct convolution
...
@@ -52,16 +51,15 @@ struct convolution
f
(
self
.
stride
,
"stride"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
group
,
"group"
),
f
(
self
.
group
,
"group"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
padding_mode
,
"padding_mode"
));
f
(
self
.
use_dynamic_same_auto_pad
,
"use_dynamic_same_auto_pad"
));
}
}
std
::
string
name
()
const
{
return
"convolution"
;
}
std
::
string
name
()
const
{
return
"convolution"
;
}
void
check_attribute_size
()
const
void
check_attribute_size
()
const
{
{
if
(
not
(
(
padding
.
size
()
=
=
stride
.
size
()
or
(
padding
.
size
()
/
2
)
=
=
stride
.
size
())
and
if
((
padding
.
size
()
!
=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!
=
stride
.
size
())
or
stride
.
size
()
=
=
dilation
.
size
())
)
stride
.
size
()
!
=
dilation
.
size
())
{
{
MIGRAPHX_THROW
(
"CONVOLUTION: inconsistent attribute sizes"
);
MIGRAPHX_THROW
(
"CONVOLUTION: inconsistent attribute sizes"
);
}
}
...
@@ -76,7 +74,8 @@ struct convolution
...
@@ -76,7 +74,8 @@ struct convolution
// num of dims of input and attribute should match
// num of dims of input and attribute should match
const
auto
input_size
=
inputs
[
0
].
max_lens
().
size
();
const
auto
input_size
=
inputs
[
0
].
max_lens
().
size
();
const
auto
padding_size
=
padding
.
size
();
const
auto
padding_size
=
padding
.
size
();
if
(
not
(
input_size
==
padding_size
/
2
+
2
or
input_size
==
padding_size
+
2
))
if
(
input_size
!=
padding_size
/
2
+
2
&&
input_size
!=
padding_size
+
2
)
{
{
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
}
}
...
@@ -93,13 +92,6 @@ struct convolution
...
@@ -93,13 +92,6 @@ struct convolution
x_shape
.
lens
().
at
(
1
)
!=
(
w_shape
.
lens
().
at
(
1
)
*
group
))
x_shape
.
lens
().
at
(
1
)
!=
(
w_shape
.
lens
().
at
(
1
)
*
group
))
MIGRAPHX_THROW
(
"CONVOLUTION: mismatched channel numbers"
);
MIGRAPHX_THROW
(
"CONVOLUTION: mismatched channel numbers"
);
std
::
vector
<
op
::
padding_mode_t
>
dyn_pad_modes
=
{
op
::
padding_mode_t
::
same_upper
,
op
::
padding_mode_t
::
same_lower
};
if
(
use_dynamic_same_auto_pad
and
not
contains
(
dyn_pad_modes
,
padding_mode
))
{
MIGRAPHX_THROW
(
"CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode"
);
}
if
(
x_shape
.
dynamic
()
or
w_shape
.
dynamic
())
if
(
x_shape
.
dynamic
()
or
w_shape
.
dynamic
())
{
{
return
dynamic_compute_shape
(
x_shape
,
w_shape
);
return
dynamic_compute_shape
(
x_shape
,
w_shape
);
...
@@ -161,7 +153,7 @@ struct convolution
...
@@ -161,7 +153,7 @@ struct convolution
dynamic_shape_push_back
(
w_shape
);
dynamic_shape_push_back
(
w_shape
);
const
size_t
num_spatial_dims
=
x_shape
.
max_lens
().
size
()
-
2
;
const
size_t
num_spatial_dims
=
x_shape
.
max_lens
().
size
()
-
2
;
if
(
use_dynamic_same_auto_pad
)
if
(
padding_mode
!=
default_
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
{
...
...
src/include/migraphx/op/deconvolution.hpp
View file @
870a396b
...
@@ -61,8 +61,8 @@ struct deconvolution
...
@@ -61,8 +61,8 @@ struct deconvolution
void
check_attribute_size
()
const
void
check_attribute_size
()
const
{
{
if
(
not
(
(
padding
.
size
()
=
=
stride
.
size
()
or
(
padding
.
size
()
/
2
)
=
=
stride
.
size
())
and
if
((
padding
.
size
()
!
=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!
=
stride
.
size
())
or
stride
.
size
()
=
=
dilation
.
size
())
)
stride
.
size
()
!
=
dilation
.
size
())
{
{
MIGRAPHX_THROW
(
"deconvolution: inconsistent attribute sizes"
);
MIGRAPHX_THROW
(
"deconvolution: inconsistent attribute sizes"
);
}
}
...
...
src/include/migraphx/op/dot.hpp
View file @
870a396b
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -38,41 +39,69 @@ struct dot
...
@@ -38,41 +39,69 @@ struct dot
std
::
string
name
()
const
{
return
"dot"
;
}
std
::
string
name
()
const
{
return
"dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
same_type
().
has
(
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
same_type
().
same_ndims
().
has
(
2
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
ndim
()
>=
2
;
}))
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dim
s operands
"
);
MIGRAPHX_THROW
(
"DOT: dot only accept
s operands with
2 or more dim
ensions
"
);
}
}
if
(
a
.
dynamic
()
or
b
.
dynamic
())
// only handle the case that the batch size of a and b are the same
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
auto
s0
=
a
.
to_dynamic
();
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
auto
s1
=
b
.
to_dynamic
();
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
if
(
s0
.
dyn_dims
()[
dim_1
]
!=
s1
.
dyn_dims
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
auto
out_dyn_dims
=
s0
.
dyn_dims
();
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
return
{
t
,
out_dyn_dims
};
}
}
else
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
{
{
MIGRAPHX_THROW
(
"DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
// only handle the case that all the dimensions except the last two are the same
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
if
(
not
std
::
equal
(
}
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: static outer dimensions of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
std
::
size_t
dim_0
=
a
.
ndim
()
-
2
;
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
std
::
size_t
dim_1
=
a
.
ndim
()
-
1
;
return
{
t
,
out_lens
};
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: static inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
=
argument
{
out
put_shape
};
argument
result
=
argument
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])(
visit_all
(
result
,
args
[
0
],
args
[
1
])(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
return
result
;
return
result
;
...
...
src/include/migraphx/op/elu.hpp
View file @
870a396b
...
@@ -32,14 +32,13 @@ namespace migraphx {
...
@@ -32,14 +32,13 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
struct
elu
struct
elu
:
unary
<
elu
>
{
{
std
::
string
name
()
const
{
return
"elu"
;
}
float
alpha
=
1
;
float
alpha
=
1
;
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
std
::
string
point_op
()
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
"${function:where}(${0} > 0, ${0}, ${alpha} * (${function:exp}(${0}) - 1))"
;
return
inputs
.
front
();
}
}
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
...
@@ -47,6 +46,11 @@ struct elu
...
@@ -47,6 +46,11 @@ struct elu
{
{
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
}
}
auto
apply
()
const
{
return
[
&
](
auto
x
)
{
return
x
>
0
?
x
:
alpha
*
std
::
expm1
(
x
);
};
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/flatten.hpp
View file @
870a396b
...
@@ -55,17 +55,47 @@ struct flatten
...
@@ -55,17 +55,47 @@ struct flatten
std
::
string
name
()
const
{
return
"flatten"
;
}
std
::
string
name
()
const
{
return
"flatten"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
s
=
inputs
[
0
];
auto
x
=
if
(
s
.
dynamic
())
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
{
auto
y
=
auto
min_lens
=
s
.
min_lens
();
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
max_lens
=
s
.
max_lens
();
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
auto
opt_lens
=
s
.
opt_lens
();
// If any of the opt values is 0, output opt will be 0
shape
::
dynamic_dimension
x
=
{
std
::
accumulate
(
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
(),
opt_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{})};
shape
::
dynamic_dimension
y
=
{
std
::
accumulate
(
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
()
+
axis
,
opt_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
};
return
{
s
.
type
(),
{
x
,
y
}};
}
else
{
check_shapes
{
inputs
,
*
this
}.
standard
();
auto
&&
lens
=
s
.
lens
();
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
s
.
type
(),
{
x
,
y
}};
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/fmod.hpp
View file @
870a396b
...
@@ -40,7 +40,6 @@ struct fmod : binary<fmod>
...
@@ -40,7 +40,6 @@ struct fmod : binary<fmod>
a
[
"commutative"
]
=
false
;
a
[
"commutative"
]
=
false
;
return
a
;
return
a
;
}
}
std
::
string
point_function
()
const
{
return
"fmod"
;
}
auto
apply
()
const
auto
apply
()
const
{
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
(
x
,
y
);
};
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
(
x
,
y
);
};
...
...
src/include/migraphx/op/layout.hpp
100755 → 100644
View file @
870a396b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OP_LAYOUT_HPP
#ifndef MIGRAPHX_GUARD_OP_LAYOUT_HPP
#define MIGRAPHX_GUARD_OP_LAYOUT_HPP
#define MIGRAPHX_GUARD_OP_LAYOUT_HPP
...
@@ -8,7 +31,6 @@
...
@@ -8,7 +31,6 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
...
src/include/migraphx/op/leaky_relu.hpp
View file @
870a396b
...
@@ -26,12 +26,13 @@
...
@@ -26,12 +26,13 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
struct
leaky_relu
struct
leaky_relu
:
unary
<
leaky_relu
>
{
{
float
alpha
=
0.01
;
float
alpha
=
0.01
;
...
@@ -41,11 +42,13 @@ struct leaky_relu
...
@@ -41,11 +42,13 @@ struct leaky_relu
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
}
}
std
::
string
point_op
()
const
{
return
"${function:where}(${0} > 0, ${0}, ${alpha} * ${0})"
;
}
std
::
string
name
()
const
{
return
"leaky_relu"
;
}
std
::
string
name
()
const
{
return
"leaky_relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
auto
apply
()
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
[
&
](
auto
x
)
{
return
x
>
0
?
x
:
x
*
alpha
;
};
return
inputs
.
front
();
}
}
};
};
...
...
src/include/migraphx/op/mod.hpp
View file @
870a396b
...
@@ -38,9 +38,9 @@ struct mod : binary<mod>
...
@@ -38,9 +38,9 @@ struct mod : binary<mod>
{
{
auto
a
=
base_attributes
();
auto
a
=
base_attributes
();
a
[
"commutative"
]
=
false
;
a
[
"commutative"
]
=
false
;
a
[
"point_op"
]
=
"${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})"
;
return
a
;
return
a
;
}
}
std
::
string
point_function
()
const
{
return
"mod"
;
}
auto
apply
()
const
auto
apply
()
const
{
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
((
std
::
remainder
(
x
,
y
))
+
y
,
y
);
};
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
((
std
::
remainder
(
x
,
y
))
+
y
,
y
);
};
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
870a396b
...
@@ -26,64 +26,105 @@
...
@@ -26,64 +26,105 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/common.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator: one input and two inputs.
* One input version uses output_lens attribute and broadcasts to it.
* Two inputs version broadcasts both inputs to the common shape at evaluation time.
*/
struct
multibroadcast
struct
multibroadcast
{
{
std
::
vector
<
std
::
size_t
>
output_lens
;
std
::
vector
<
std
::
size_t
>
output_lens
=
{};
// optional attribute
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
output_lens
,
"out_lens"
));
return
pack
(
f
(
self
.
output_lens
,
"out_lens"
)
,
f
(
self
.
output_dyn_dims
,
"out_dyn_dims"
)
);
}
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
auto
t
=
inputs
.
at
(
0
).
type
();
{
auto
s0
=
inputs
.
at
(
0
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
if
(
s0
.
max_lens
().
empty
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input
s
dimensions should
<= output size
"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should
be > 0
"
);
}
}
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
make_bcast_strides
=
[
&
](
std
::
vector
<
std
::
size_t
>
bcast_lens
,
std
::
size_t
offset
)
{
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
std
::
vector
<
size_t
>
bcast_strides
(
bcast_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
bcast_lens
[
i
+
offset
]
==
s0
.
lens
()[
i
])
{
bcast_strides
[
i
+
offset
]
=
s0
.
strides
()[
i
];
}
}
return
bcast_strides
;
};
if
(
inputs
.
size
()
==
1
)
{
{
if
(
output_lens
[
i
+
offset
]
!=
input
.
lens
()[
i
]
and
in
put
.
lens
()[
i
]
!=
1
)
if
(
s0
.
lens
().
size
()
>
out
put
_
lens
.
size
()
)
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input
.
lens
())
+
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should <= output size"
);
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
s0
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_t
i
=
s0
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
s0
.
lens
()[
i
]
and
s0
.
lens
()[
i
]
!=
1
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
s0
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
auto
bcast_strides
=
make_bcast_strides
(
output_lens
,
offset
);
return
{
t
,
output_lens
,
std
::
move
(
bcast_strides
)};
}
else
{
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
// two inputs
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
{
bcast_strides
[
i
+
offset
]
=
input
.
strides
()[
i
];
if
(
not
output_dyn_dims
.
empty
())
{
return
{
t
,
output_dyn_dims
};
}
return
{
t
,
compute_broadcasted_dyn_dims
(
s0
,
s1
)};
}
else
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
auto
offset
=
bcast_lens
.
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
bcast_lens
,
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
}
}
}
}
return
{
t
,
output_lens
,
bcast_strides
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/pad.hpp
View file @
870a396b
...
@@ -59,18 +59,29 @@ struct pad
...
@@ -59,18 +59,29 @@ struct pad
std
::
string
name
()
const
{
return
"pad"
;
}
std
::
string
name
()
const
{
return
"pad"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
const
auto
&
s0
=
inputs
.
front
();
std
::
vector
<
std
::
size_t
>
rdims
(
idims
.
begin
(),
idims
.
end
());
if
(
s0
.
dynamic
())
std
::
size_t
num_dims
=
rdims
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_dims
;
i
++
)
{
{
rdims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
num_dims
];
auto
out_dyn_dims
=
s0
.
dyn_dims
();
for
(
std
::
size_t
i
=
0
;
i
<
s0
.
ndim
();
++
i
)
{
out_dyn_dims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
s0
.
ndim
()];
}
return
{
s0
.
type
(),
out_dyn_dims
};
}
else
{
auto
&&
idims
=
s0
.
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
idims
.
begin
(),
idims
.
end
());
std
::
size_t
num_dims
=
rdims
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_dims
;
i
++
)
{
rdims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
num_dims
];
}
shape
s
{
s0
.
type
(),
rdims
};
return
s
;
}
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
return
s
;
}
}
std
::
size_t
pad_ndims
()
const
std
::
size_t
pad_ndims
()
const
...
...
Prev
1
2
3
4
5
6
7
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment