Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b878f78f
Commit
b878f78f
authored
Aug 12, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/develop' into rewrite-fast-gelu
parents
3b414cc2
55cb7d3a
Changes
197
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
206 additions
and
135 deletions
+206
-135
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+3
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+0
-1
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+1
-6
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+1
-8
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+122
-32
src/include/migraphx/op/cos.hpp
src/include/migraphx/op/cos.hpp
+1
-8
src/include/migraphx/op/cosh.hpp
src/include/migraphx/op/cosh.hpp
+1
-8
src/include/migraphx/op/deconvolution.hpp
src/include/migraphx/op/deconvolution.hpp
+3
-6
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+1
-9
src/include/migraphx/op/div.hpp
src/include/migraphx/op/div.hpp
+1
-9
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+1
-7
src/include/migraphx/op/elu.hpp
src/include/migraphx/op/elu.hpp
+1
-7
src/include/migraphx/op/equal.hpp
src/include/migraphx/op/equal.hpp
+1
-4
src/include/migraphx/op/exp.hpp
src/include/migraphx/op/exp.hpp
+1
-8
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+1
-8
src/include/migraphx/op/fmod.hpp
src/include/migraphx/op/fmod.hpp
+63
-0
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+1
-0
src/include/migraphx/op/get_tuple_elem.hpp
src/include/migraphx/op/get_tuple_elem.hpp
+1
-3
src/include/migraphx/op/identity.hpp
src/include/migraphx/op/identity.hpp
+2
-9
src/include/migraphx/op/if_op.hpp
src/include/migraphx/op/if_op.hpp
+0
-1
No files found.
src/include/migraphx/op/common.hpp
View file @
b878f78f
...
...
@@ -37,7 +37,9 @@ enum padding_mode_t
{
default_
,
// NOLINT
same
,
valid
valid
,
same_lower
,
same_upper
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
...
...
src/include/migraphx/op/concat.hpp
View file @
b878f78f
...
...
@@ -36,7 +36,6 @@
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/tune_axis.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/contiguous.hpp
View file @
b878f78f
...
...
@@ -24,15 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CONTIGUOUS_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONTIGUOUS_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/convert.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/convolution.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
...
...
@@ -47,8 +41,9 @@ struct convolution
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
dilation
=
{
1
,
1
};
int
group
=
1
;
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
padding_mode_t
padding_mode
=
default_
;
bool
use_dynamic_same_auto_pad
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -57,7 +52,8 @@ struct convolution
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
group
,
"group"
),
f
(
self
.
padding_mode
,
"padding_mode"
));
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
use_dynamic_same_auto_pad
,
"use_dynamic_same_auto_pad"
));
}
std
::
string
name
()
const
{
return
"convolution"
;
}
...
...
@@ -75,43 +71,137 @@ struct convolution
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
min_ndims
(
3
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
).
same_type
().
same_ndims
().
min_ndims
(
3
);
check_attribute_size
();
//
dim num
of input and attribute should match
auto
input_size
=
inputs
[
0
].
lens
().
size
();
auto
padding_size
=
padding
.
size
();
//
num of dims
of input and attribute should match
const
auto
input_size
=
inputs
[
0
].
max_
lens
().
size
();
const
auto
padding_size
=
padding
.
size
();
if
(
not
(
input_size
==
padding_size
/
2
+
2
or
input_size
==
padding_size
+
2
))
{
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
}
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
w
eights
=
inputs
.
at
(
1
);
size_t
kdims
=
input_size
-
2
;
if
(
k
dims
!=
this
->
kdims
())
const
shape
&
x_shape
=
inputs
.
at
(
0
);
const
shape
&
w
_shape
=
inputs
.
at
(
1
);
const
size_t
num_spatial_dims
=
input_size
-
2
;
if
(
num_spatial_
dims
!=
this
->
kdims
())
{
MIGRAPHX_THROW
(
"
convolution
: input k-dims does not match attribute size"
);
MIGRAPHX_THROW
(
"
CONVOLUTION
: input k-dims does not match attribute size"
);
}
if
(
input
.
lens
().
at
(
1
)
!=
(
weights
.
lens
().
at
(
1
)
*
group
))
MIGRAPHX_THROW
(
"CONVOLUTION: Mismatch channel numbers"
);
if
(
not
x_shape
.
dynamic
()
and
not
w_shape
.
dynamic
()
and
x_shape
.
lens
().
at
(
1
)
!=
(
w_shape
.
lens
().
at
(
1
)
*
group
))
MIGRAPHX_THROW
(
"CONVOLUTION: mismatched channel numbers"
);
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
std
::
vector
<
op
::
padding_mode_t
>
dyn_pad_modes
=
{
op
::
padding_mode_t
::
same_upper
,
op
::
padding_mode_t
::
same_lower
};
if
(
use_dynamic_same_auto_pad
and
not
contains
(
dyn_pad_modes
,
padding_mode
))
{
MIGRAPHX_THROW
(
"CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode"
);
}
if
(
x_shape
.
dynamic
()
or
w_shape
.
dynamic
())
{
return
dynamic_compute_shape
(
x_shape
,
w_shape
);
}
else
{
return
fixed_compute_shape
(
x_shape
,
w_shape
);
}
}
std
::
vector
<
std
::
size_t
>
calc_conv_lens
(
std
::
vector
<
std
::
size_t
>
x_lens
,
std
::
vector
<
std
::
size_t
>
w_lens
)
const
{
const
size_t
num_spatial_dims
=
x_lens
.
size
()
-
2
;
std
::
vector
<
size_t
>
ret
=
{};
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
i
++
)
{
if
(
x_lens
[
i
]
==
0
or
w_lens
[
i
]
==
0
)
{
// for handling when a dimension = 0 (opt of dynamic_dimension)
ret
.
push_back
(
0
);
}
else
{
auto
padding_factor
=
2
*
padding
[
i
];
if
(
padding
.
size
()
==
2
*
num_spatial_dims
)
{
// when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
padding_factor
=
padding
[
i
]
+
padding
[
i
+
num_spatial_dims
];
}
ret
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
x_lens
[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
w_lens
[
i
+
2
]
-
1
))
+
padding_factor
)
/
stride
[
i
]
+
1
)));
}
}
return
ret
;
}
for
(
size_t
i
=
0
;
i
<
kdims
;
i
++
)
shape
dynamic_compute_shape
(
shape
x_shape
,
shape
w_shape
)
const
{
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
auto
dynamic_shape_push_back
=
[
&
](
const
shape
&
input_shape
)
{
if
(
input_shape
.
dynamic
())
{
output_dyn_dims
.
push_back
(
input_shape
.
dyn_dims
().
at
(
0
));
}
else
{
auto
l
=
input_shape
.
lens
().
at
(
0
);
output_dyn_dims
.
push_back
({
l
,
l
,
0
});
}
};
dynamic_shape_push_back
(
x_shape
);
dynamic_shape_push_back
(
w_shape
);
const
size_t
num_spatial_dims
=
x_shape
.
max_lens
().
size
()
-
2
;
if
(
use_dynamic_same_auto_pad
)
{
auto
padding_factor
=
2
*
padding
[
i
];
if
(
padding_size
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
weights
.
lens
()[
i
+
2
]
-
1
))
+
padding_factor
)
/
stride
[
i
]
+
1
)));
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
auto
ceil_div
=
[](
std
::
size_t
x
,
std
::
size_t
y
)
{
return
(
x
+
y
-
1
)
/
y
;
};
auto
s
=
stride
[
i
];
if
(
x_shape
.
dynamic
())
{
auto
x
=
x_shape
.
dyn_dims
()[
i
+
2
];
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
ceil_div
(
x
.
opt
,
s
)});
}
else
{
auto
od
=
ceil_div
(
x_shape
.
lens
()[
i
+
2
],
s
);
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
od
,
od
,
0
});
}
}
}
else
{
auto
min_spatial_dims
=
calc_conv_lens
(
x_shape
.
min_lens
(),
w_shape
.
max_lens
());
auto
max_spatial_dims
=
calc_conv_lens
(
x_shape
.
max_lens
(),
w_shape
.
min_lens
());
auto
opt_spatial_dims
=
calc_conv_lens
(
x_shape
.
opt_lens
(),
w_shape
.
opt_lens
());
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
}
return
shape
{
x_shape
.
type
(),
output_dyn_dims
};
}
return
inputs
[
0
].
with_lens
(
output_lens
);
shape
fixed_compute_shape
(
shape
x_shape
,
shape
w_shape
)
const
{
std
::
vector
<
size_t
>
output_lens
{
x_shape
.
lens
()[
0
],
w_shape
.
lens
()[
0
]};
auto
spatial_lens
=
calc_conv_lens
(
x_shape
.
lens
(),
w_shape
.
lens
());
std
::
for_each
(
spatial_lens
.
begin
(),
spatial_lens
.
end
(),
[
&
output_lens
](
auto
x
)
{
output_lens
.
push_back
(
x
);
});
return
x_shape
.
with_lens
(
output_lens
);
}
size_t
kdims
()
const
...
...
src/include/migraphx/op/cos.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_COS_HPP
#define MIGRAPHX_GUARD_OPERATORS_COS_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/cosh.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_COSH_HPP
#define MIGRAPHX_GUARD_OPERATORS_COSH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/deconvolution.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,13 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/value.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/shape_for_each.hpp>
#include <cmath>
#include <utility>
...
...
src/include/migraphx/op/dequantizelinear.hpp
View file @
b878f78f
...
...
@@ -24,20 +24,12 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/div.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_DIV_HPP
#define MIGRAPHX_GUARD_OPERATORS_DIV_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/binary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/dot.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_DOT_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/elu.hpp
View file @
b878f78f
...
...
@@ -24,15 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_ELU_HPP
#define MIGRAPHX_GUARD_OPERATORS_ELU_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/equal.hpp
View file @
b878f78f
...
...
@@ -24,11 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP
#define MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/binary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/exp.hpp
View file @
b878f78f
...
...
@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_EXP_HPP
#define MIGRAPHX_GUARD_OPERATORS_EXP_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/flatten.hpp
View file @
b878f78f
...
...
@@ -24,18 +24,11 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_FLATTEN_HPP
#define MIGRAPHX_GUARD_OPERATORS_FLATTEN_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/fmod.hpp
0 → 100644
View file @
b878f78f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
fmod
:
binary
<
fmod
>
{
std
::
string
name
()
const
{
return
"fmod"
;
}
value
attributes
()
const
{
auto
a
=
base_attributes
();
a
[
"commutative"
]
=
false
;
return
a
;
}
std
::
string
point_function
()
const
{
return
"fmod"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
(
x
,
y
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/gathernd.hpp
View file @
b878f78f
...
...
@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/get_tuple_elem.hpp
View file @
b878f78f
...
...
@@ -26,10 +26,8 @@
#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <utility>
namespace
migraphx
{
...
...
src/include/migraphx/op/identity.hpp
View file @
b878f78f
...
...
@@ -24,15 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_IDENTITY_HPP
#define MIGRAPHX_GUARD_OPERATORS_IDENTITY_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/unary.hpp>
#include <migraphx/argument.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/if_op.hpp
View file @
b878f78f
...
...
@@ -26,7 +26,6 @@
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
...
...
Prev
1
2
3
4
5
6
7
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment