Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
523a78c7
Commit
523a78c7
authored
Sep 25, 2018
by
Scott Thornton
Browse files
Added slice w/ tests
parent
492bf901
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
108 additions
and
31 deletions
+108
-31
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+50
-22
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+1
-0
src/shape.cpp
src/shape.cpp
+6
-0
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+37
-9
test/op_shape_test.cpp
test/op_shape_test.cpp
+14
-0
No files found.
src/include/migraph/operators.hpp
View file @
523a78c7
...
...
@@ -317,43 +317,71 @@ struct slice
std
::
vector
<
int64_t
>
starts
;
std
::
vector
<
int64_t
>
ends
;
std
::
string
name
()
const
{
return
"slice"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
auto
fix_index
(
const
std
::
vector
<
std
::
size_t
>&
lens
,
std
::
size_t
axis
,
int64_t
index
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
std
::
vector
<
int64_t
>
t_axes
(
old_lens
.
size
());
if
(
axes
.
size
()
==
0
)
std
::
size_t
r
=
std
::
min
(
index
,
static_cast
<
int64_t
>
(
lens
[
axis
]));
if
(
r
<
0
)
r
+=
lens
[
axis
];
return
r
;
}
auto
compute_offset
(
const
shape
&
s
)
const
{
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
if
(
axes
.
size
()
>
0
)
{
std
::
iota
(
t_axes
.
begin
(),
t_axes
.
end
(),
0
);
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
offset
+=
fix_index
(
lens
,
axis
,
starts
[
i
])
*
strides
[
axis
];
}
}
else
{
std
::
copy
(
axes
.
begin
(),
axes
.
end
(),
t_axes
.
begin
());
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
{
offset
+=
fix_index
(
lens
,
axis
,
starts
[
axis
])
*
strides
[
axis
];
}
}
if
(
starts
.
size
()
||
t_axes
.
size
()
!=
ends
.
size
())
return
offset
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
// std::vector<int64_t> t_axes(old_lens.size());
// if(axes.size() == 0)
// {
// std::iota(t_axes.begin(), t_axes.end(), 0);
// }
// else
// {
// std::copy(axes.begin(), axes.end(), t_axes.begin());
// }
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
{
MIGRAPH_THROW
(
"inconsistent sizes"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
copy
(
old_lens
.
begin
(),
old_lens
.
end
(),
new_lens
.
begin
());
auto
fix_index
=
[
&
](
std
::
size_t
axis
,
int64_t
index
)
{
auto
r
=
std
::
min
(
index
,
static_cast
<
int64_t
>
(
old_lens
[
axis
]
-
1
));
if
(
r
<
0
)
r
+=
old_lens
[
axis
];
return
r
;
};
for
(
std
::
size_t
i
=
0
;
i
<
t_axes
.
size
();
i
++
)
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
t_axes
[
i
];
new_lens
[
axis
]
=
fix_index
(
axis
,
ends
[
i
])
-
fix_index
(
axis
,
starts
[
i
]);
auto
axis
=
axes
[
i
];
new_lens
[
axis
]
=
fix_index
(
old_lens
,
axis
,
ends
[
i
])
-
fix_index
(
old_lens
,
axis
,
starts
[
i
]);
}
return
shape
{
t
,
new_lens
,
old_strides
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
return
{
std
::
move
(
output_shape
),
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
};
...
...
src/include/migraph/shape.hpp
View file @
523a78c7
...
...
@@ -63,6 +63,7 @@ struct shape
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
std
::
size_t
elements
()
const
;
std
::
size_t
bytes
()
const
;
std
::
size_t
type_size
()
const
;
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
...
...
src/shape.cpp
View file @
523a78c7
...
...
@@ -98,6 +98,12 @@ std::size_t shape::bytes() const
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
*
this
->
element_space
();
}
std
::
size_t
shape
::
type_size
()
const
{
std
::
size_t
n
=
0
;
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
;
}
std
::
size_t
shape
::
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
{
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
...
...
test/cpu_ops_test.cpp
View file @
523a78c7
...
...
@@ -8,15 +8,42 @@
void
slice_test
()
{
migraph
::
program
p
;
std
::
vector
<
float
>
data
(
4
*
3
*
2
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
,
2
,
3
}};
auto
l0
=
p
.
add_literal
(
migraph
::
literal
{
s
,
data
});
p
.
add_instruction
(
migraph
::
squeeze
{{
0
},
{
0
},
{
2
}},
l0
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
EXPECT
(
result
.
get_shape
()
==
s2
);
{
migraph
::
program
p
;
std
::
vector
<
int
>
data
(
2
*
2
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
migraph
::
shape
s
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
auto
l0
=
p
.
add_literal
(
migraph
::
literal
{
s
,
data
});
p
.
add_instruction
(
migraph
::
slice
{{
2
},
{
1
},
{
3
}},
l0
);
migraph
::
shape
s2
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}};
EXPECT
(
p
.
get_shape
()
==
s2
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
migraph
::
shape
sresult
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
4
,
2
,
1
}};
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
gold
=
{
1
,
2
,
4
,
5
,
7
,
8
,
10
,
11
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
result
.
get_shape
()
==
sresult
);
}
{
migraph
::
program
p
;
std
::
vector
<
int
>
data
(
2
*
2
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
migraph
::
shape
s
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
auto
l0
=
p
.
add_literal
(
migraph
::
literal
{
s
,
data
});
p
.
add_instruction
(
migraph
::
slice
{{
0
,
1
,
2
},
{
0
,
0
,
0
},
{
2
,
2
,
2
}},
l0
);
migraph
::
shape
s2
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}};
EXPECT
(
p
.
get_shape
()
==
s2
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
migraph
::
shape
sresult
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
4
,
2
,
1
}};
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
gold
=
{
0
,
1
,
3
,
4
,
6
,
7
,
9
,
10
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
result
.
get_shape
()
==
sresult
);
}
}
void
squeeze_test
()
...
...
@@ -877,6 +904,7 @@ void contiguous_test()
int
main
()
{
slice_test
();
squeeze_test
();
unsqueeze_test
();
exp_test
();
...
...
test/op_shape_test.cpp
View file @
523a78c7
...
...
@@ -130,6 +130,19 @@ void flatten_shape()
throws_shape
(
migraph
::
flatten
{
5
},
input
);
}
void
slice_shape
()
{
migraph
::
shape
input
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}},
migraph
::
slice
{{
2
},
{
1
},
{
3
}},
input
);
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}},
migraph
::
slice
{{
0
,
1
,
2
},
{
0
,
0
,
1
},
{
2
,
2
,
3
}},
input
);
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
1
},
{
6
,
3
,
1
}},
migraph
::
slice
{{
2
},
{
2
},
{
10
}},
input
);
}
int
main
()
{
batch_norm_inference_shape
();
...
...
@@ -138,4 +151,5 @@ int main()
contiguous_shape
();
reshape_shape
();
flatten_shape
();
slice_shape
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment