Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1ce83fbe
Commit
1ce83fbe
authored
Aug 18, 2018
by
Paul
Browse files
Make shape ref counted
parent
0dd8ee4f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
48 deletions
+83
-48
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+6
-7
src/shape.cpp
src/shape.cpp
+77
-41
No files found.
src/include/migraph/shape.hpp
View file @
1ce83fbe
...
@@ -5,11 +5,14 @@
...
@@ -5,11 +5,14 @@
#include <cassert>
#include <cassert>
#include <ostream>
#include <ostream>
#include <numeric>
#include <numeric>
#include <memory>
#include <migraph/errors.hpp>
#include <migraph/errors.hpp>
namespace
migraph
{
namespace
migraph
{
struct
shape_impl
;
struct
shape
struct
shape
{
{
...
@@ -136,7 +139,7 @@ struct shape
...
@@ -136,7 +139,7 @@ struct shape
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
void
visit_type
(
Visitor
v
)
const
{
{
switch
(
this
->
m_
type
)
switch
(
this
->
type
()
)
{
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
case x: v(as<t>()); return;
...
@@ -147,12 +150,8 @@ struct shape
...
@@ -147,12 +150,8 @@ struct shape
}
}
private:
private:
type_t
m_type
;
std
::
shared_ptr
<
const
shape_impl
>
impl
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_strides
;
bool
m_standard
;
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
std
::
string
type_string
()
const
;
std
::
string
type_string
()
const
;
};
};
...
...
src/shape.cpp
View file @
1ce83fbe
...
@@ -8,45 +8,90 @@
...
@@ -8,45 +8,90 @@
namespace
migraph
{
namespace
migraph
{
shape
::
shape
()
:
m_type
(
float_type
),
m_standard
(
false
)
{}
struct
shape_impl
{
static
std
::
shared_ptr
<
shape_impl
>
default_shape
()
{
static
std
::
shared_ptr
<
shape_impl
>
result
=
std
::
make_shared
<
shape_impl
>
();
return
result
;
}
shape_impl
()
:
m_type
(
shape
::
float_type
),
m_standard
(
false
)
{}
shape_impl
(
shape
::
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_standard
(
true
)
{}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
{
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
"At least one stride must be non-zero"
);
m_standard
=
this
->
elements
()
==
this
->
element_space
()
and
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
}
shape
::
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_strides
;
bool
m_standard
;
void
calculate_strides
()
{
m_strides
.
clear
();
m_strides
.
resize
(
m_lens
.
size
(),
0
);
if
(
m_strides
.
empty
())
return
;
m_strides
.
back
()
=
1
;
std
::
partial_sum
(
m_lens
.
rbegin
(),
m_lens
.
rend
()
-
1
,
m_strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
std
::
size_t
element_space
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
return
std
::
inner_product
(
m_lens
.
begin
(),
m_lens
.
end
(),
m_strides
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[](
std
::
size_t
l
,
std
::
size_t
s
)
{
return
(
l
-
1
)
*
s
;
})
+
1
;
}
std
::
size_t
elements
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
return
std
::
accumulate
(
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
};
shape
::
shape
(
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_standard
(
true
)
{}
shape
::
shape
()
:
impl
(
shape_impl
::
default_shape
())
{}
shape
::
shape
(
type_t
t
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
))
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
))
)
{
{
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
}
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
),
std
::
move
(
s
)))
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
"At least one stride must be non-zero"
);
m_standard
=
this
->
packed
()
and
not
this
->
transposed
();
}
void
shape
::
calculate_strides
()
{
{
m_strides
.
clear
();
m_strides
.
resize
(
m_lens
.
size
(),
0
);
if
(
m_strides
.
empty
())
return
;
m_strides
.
back
()
=
1
;
std
::
partial_sum
(
m_lens
.
rbegin
(),
m_lens
.
rend
()
-
1
,
m_strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
m_type
;
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
m_strides
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
elements
()
const
std
::
size_t
shape
::
elements
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
impl
->
elements
();
if
(
this
->
lens
().
empty
())
return
0
;
return
std
::
accumulate
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
}
std
::
size_t
shape
::
bytes
()
const
std
::
size_t
shape
::
bytes
()
const
{
{
...
@@ -98,25 +143,16 @@ bool shape::broadcasted() const
...
@@ -98,25 +143,16 @@ bool shape::broadcasted() const
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
}
}
bool
shape
::
standard
()
const
{
return
this
->
m_standard
;
}
bool
shape
::
standard
()
const
{
return
impl
->
m_standard
;
}
std
::
size_t
shape
::
element_space
()
const
std
::
size_t
shape
::
element_space
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
impl
->
element_space
();
if
(
this
->
lens
().
empty
())
return
0
;
return
std
::
inner_product
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[](
std
::
size_t
l
,
std
::
size_t
s
)
{
return
(
l
-
1
)
*
s
;
})
+
1
;
}
}
std
::
string
shape
::
type_string
()
const
std
::
string
shape
::
type_string
()
const
{
{
switch
(
this
->
m_
type
)
switch
(
this
->
type
()
)
{
{
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
case x: return #x;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment