Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
140fde0a
"vscode:/vscode.git/clone" did not exist on "77212cc1da5c79de8b1e502e8df4842c7115fc41"
Commit
140fde0a
authored
May 01, 2019
by
Shucai Xiao
Browse files
code refinement.
parent
43194a31
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
26 deletions
+37
-26
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+3
-1
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+15
-12
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+17
-11
test/type_conversion.cpp
test/type_conversion.cpp
+2
-2
No files found.
src/include/migraphx/op/binary.hpp
View file @
140fde0a
...
...
@@ -18,11 +18,12 @@ struct binary : op_name<Derived>
return
{
s
.
type
()};
return
{
s
.
type
(),
s
.
lens
()};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
if
(
input1
.
get_shape
().
standar
d
()
and
input2
.
get_shape
().
standar
d
())
if
(
input1
.
get_shape
().
packe
d
()
and
input2
.
get_shape
().
packe
d
())
{
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
...
...
@@ -38,6 +39,7 @@ struct binary : op_name<Derived>
});
}
});
return
result
;
}
};
...
...
src/include/migraphx/op/convert.hpp
View file @
140fde0a
...
...
@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array>
#include <migraphx/op/
bi
nary.hpp>
#include <migraphx/op/
u
nary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
...
...
@@ -17,7 +17,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
convert
struct
convert
:
unary
<
convert
>
{
shape
::
type_t
target_type
=
shape
::
half_type
;
...
...
@@ -27,23 +27,26 @@ struct convert
return
pack
(
f
(
self
.
target_type
,
"target_type"
));
}
std
::
string
name
()
const
{
return
"convert"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
{
target_type
,
inputs
.
front
().
lens
(),
inputs
.
front
().
strides
()};
if
(
inputs
.
at
(
0
).
packed
())
{
return
{
target_type
,
inputs
.
at
(
0
).
lens
(),
inputs
.
at
(
0
).
strides
()};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
else
{
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
output
)
{
args
.
front
().
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
return
{
target_type
,
inputs
.
at
(
0
).
lens
()};
}
}
return
result
;
auto
apply
()
const
{
return
[](
auto
x
)
{
return
x
;
};
}
convert
(
shape
::
type_t
t
)
:
target_type
{
t
}
{
}
convert
()
{
}
};
}
// namespace op
...
...
src/include/migraphx/op/unary.hpp
View file @
140fde0a
...
...
@@ -15,25 +15,31 @@ struct unary : op_name<Derived>
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
if
(
input
.
get_shape
().
standard
())
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
if
(
input
.
get_shape
().
packed
())
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
return
result
;
}
else
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
input
(
idx
.
begin
(),
idx
.
end
()));
});
}
return
result
;
});
});
return
result
;
}
};
...
...
test/type_conversion.cpp
View file @
140fde0a
...
...
@@ -29,10 +29,10 @@ TEST_CASE(param_add)
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
p1
=
p
.
add_parameter
(
"x"
,
s
);
auto
hp1
=
p
.
insert_instruction
(
std
::
next
(
p1
),
migraphx
::
op
::
convert
{
migraphx
::
shape
::
half_type
},
p1
);
std
::
next
(
p1
),
migraphx
::
op
::
convert
{},
p1
);
auto
p2
=
p
.
add_parameter
(
"y"
,
s
);
auto
hp2
=
p
.
insert_instruction
(
std
::
next
(
p2
),
migraphx
::
op
::
convert
{
migraphx
::
shape
::
half_type
},
p2
);
std
::
next
(
p2
),
migraphx
::
op
::
convert
{},
p2
);
auto
hs
=
p
.
add_instruction
(
migraphx
::
op
::
add
{},
hp1
,
hp2
);
p
.
add_instruction
(
migraphx
::
op
::
convert
{
migraphx
::
shape
::
float_type
},
hs
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment