Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1ce83fbe
Commit
1ce83fbe
authored
Aug 18, 2018
by
Paul
Browse files
Make shape ref counted
parent
0dd8ee4f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
48 deletions
+83
-48
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+6
-7
src/shape.cpp
src/shape.cpp
+77
-41
No files found.
src/include/migraph/shape.hpp
View file @
1ce83fbe
...
...
@@ -5,11 +5,14 @@
#include <cassert>
#include <ostream>
#include <numeric>
#include <memory>
#include <migraph/errors.hpp>
namespace
migraph
{
struct
shape_impl
;
struct
shape
{
...
...
@@ -136,7 +139,7 @@ struct shape
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
{
switch
(
this
->
m_
type
)
switch
(
this
->
type
()
)
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
...
...
@@ -147,12 +150,8 @@ struct shape
}
private:
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_strides
;
bool
m_standard
;
void
calculate_strides
();
std
::
shared_ptr
<
const
shape_impl
>
impl
;
std
::
size_t
element_space
()
const
;
std
::
string
type_string
()
const
;
};
...
...
src/shape.cpp
View file @
1ce83fbe
...
...
@@ -8,45 +8,90 @@
namespace
migraph
{
shape
::
shape
()
:
m_type
(
float_type
),
m_standard
(
false
)
{}
struct
shape_impl
{
static
std
::
shared_ptr
<
shape_impl
>
default_shape
()
{
static
std
::
shared_ptr
<
shape_impl
>
result
=
std
::
make_shared
<
shape_impl
>
();
return
result
;
}
shape_impl
()
:
m_type
(
shape
::
float_type
),
m_standard
(
false
)
{}
shape_impl
(
shape
::
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_standard
(
true
)
{}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
{
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
"At least one stride must be non-zero"
);
m_standard
=
this
->
elements
()
==
this
->
element_space
()
and
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
}
shape
::
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_strides
;
bool
m_standard
;
void
calculate_strides
()
{
m_strides
.
clear
();
m_strides
.
resize
(
m_lens
.
size
(),
0
);
if
(
m_strides
.
empty
())
return
;
m_strides
.
back
()
=
1
;
std
::
partial_sum
(
m_lens
.
rbegin
(),
m_lens
.
rend
()
-
1
,
m_strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
std
::
size_t
element_space
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
return
std
::
inner_product
(
m_lens
.
begin
(),
m_lens
.
end
(),
m_strides
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[](
std
::
size_t
l
,
std
::
size_t
s
)
{
return
(
l
-
1
)
*
s
;
})
+
1
;
}
std
::
size_t
elements
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
return
std
::
accumulate
(
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
};
shape
::
shape
(
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_standard
(
true
)
{}
shape
::
shape
()
:
impl
(
shape_impl
::
default_shape
())
{}
shape
::
shape
(
type_t
t
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
))
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
))
)
{
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
"At least one stride must be non-zero"
);
m_standard
=
this
->
packed
()
and
not
this
->
transposed
();
}
void
shape
::
calculate_strides
()
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
),
std
::
move
(
s
)))
{
m_strides
.
clear
();
m_strides
.
resize
(
m_lens
.
size
(),
0
);
if
(
m_strides
.
empty
())
return
;
m_strides
.
back
()
=
1
;
std
::
partial_sum
(
m_lens
.
rbegin
(),
m_lens
.
rend
()
-
1
,
m_strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
m_strides
;
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
elements
()
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
lens
().
empty
())
return
0
;
return
std
::
accumulate
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
return
impl
->
elements
();
}
std
::
size_t
shape
::
bytes
()
const
{
...
...
@@ -98,25 +143,16 @@ bool shape::broadcasted() const
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
}
bool
shape
::
standard
()
const
{
return
this
->
m_standard
;
}
bool
shape
::
standard
()
const
{
return
impl
->
m_standard
;
}
std
::
size_t
shape
::
element_space
()
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
lens
().
empty
())
return
0
;
return
std
::
inner_product
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[](
std
::
size_t
l
,
std
::
size_t
s
)
{
return
(
l
-
1
)
*
s
;
})
+
1
;
return
impl
->
element_space
();
}
std
::
string
shape
::
type_string
()
const
{
switch
(
this
->
m_
type
)
switch
(
this
->
type
()
)
{
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment