Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dfa79e73
Commit
dfa79e73
authored
Apr 02, 2019
by
Paul
Browse files
Add evaluation of binary operators
parent
cfbdef6b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
12 deletions
+38
-12
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+38
-12
No files found.
src/include/migraphx/operators.hpp
View file @
dfa79e73
...
...
@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
...
...
@@ -1117,8 +1118,14 @@ struct scalar
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
template
<
class
Derived
>
struct
binary
{
std
::
string
name
()
const
{
static
const
std
::
string
&
name
=
get_type_name
<
Derived
>
();
return
name
.
substr
(
name
.
rfind
(
"::"
)
+
2
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
...
...
@@ -1126,36 +1133,55 @@ struct binary
auto
lens
=
inputs
.
at
(
0
).
lens
();
return
{
t
,
lens
};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
if
(
input1
.
get_shape
().
standard
()
and
input2
.
get_shape
().
standard
())
{
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
input2
.
begin
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
}
else
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
input1
(
idx
.
begin
(),
idx
.
end
()),
input2
(
idx
.
begin
(),
idx
.
end
()));
});
}
});
return
result
;
}
};
struct
add
:
binary
struct
add
:
binary
<
add
>
{
std
::
string
name
()
const
{
return
"add"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
x
+
y
;
}
;
}
};
struct
sub
:
binary
struct
sub
:
binary
<
sub
>
{
std
::
string
name
()
const
{
return
"sub"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
x
-
y
;
}
;
}
};
struct
mul
:
binary
struct
mul
:
binary
<
mul
>
{
std
::
string
name
()
const
{
return
"mul"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
x
*
y
;
}
;
}
};
struct
div
:
binary
struct
div
:
binary
<
div
>
{
std
::
string
name
()
const
{
return
"div"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
x
/
y
;
}
;
}
};
struct
max
:
binary
struct
max
:
binary
<
max
>
{
std
::
string
name
()
const
{
return
"max"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
max
(
x
,
y
);
}
;
}
};
struct
min
:
binary
struct
min
:
binary
<
min
>
{
std
::
string
name
()
const
{
return
"min"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
min
(
x
,
y
);
}
;
}
};
struct
load
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment