Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ffcd5b35
Commit
ffcd5b35
authored
Aug 02, 2018
by
Paul
Browse files
Add standard attribute to shape
parent
bc70ef12
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
52 additions
and
25 deletions
+52
-25
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+1
-1
src/include/migraph/literal.hpp
src/include/migraph/literal.hpp
+1
-7
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+3
-1
src/include/migraph/tensor_view.hpp
src/include/migraph/tensor_view.hpp
+5
-5
src/shape.cpp
src/shape.cpp
+10
-7
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+6
-3
test/shape_test.cpp
test/shape_test.cpp
+26
-1
No files found.
src/auto_contiguous.cpp
View file @
ffcd5b35
...
@@ -11,7 +11,7 @@ void auto_contigous::apply(program& p) const
...
@@ -11,7 +11,7 @@ void auto_contigous::apply(program& p) const
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
shape
s
=
ins
->
result
;
shape
s
=
ins
->
result
;
if
(
not
s
.
packed
()
or
s
.
broadcaste
d
())
if
(
not
s
.
standar
d
())
{
{
auto
prev
=
p
.
insert_instruction
(
ins
,
ins
->
op
,
ins
->
arguments
);
auto
prev
=
p
.
insert_instruction
(
ins
,
ins
->
op
,
ins
->
arguments
);
p
.
replace_instruction
(
ins
,
contiguous
{},
prev
);
p
.
replace_instruction
(
ins
,
contiguous
{},
prev
);
...
...
src/include/migraph/literal.hpp
View file @
ffcd5b35
...
@@ -68,7 +68,7 @@ struct literal : raw_data<literal>
...
@@ -68,7 +68,7 @@ struct literal : raw_data<literal>
template
<
class
Iterator
>
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
void
fill
(
Iterator
start
,
Iterator
end
)
{
{
if
(
m_shape
.
packe
d
())
if
(
m_shape
.
standar
d
())
{
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
}
...
@@ -82,12 +82,6 @@ struct literal : raw_data<literal>
...
@@ -82,12 +82,6 @@ struct literal : raw_data<literal>
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
});
});
});
});
// visit_all(*this)([&](auto output) {
// shape_for_each(output.get_shape(), [&](const auto& idx) {
// it++;
// output(idx.begin(), idx.end()) = *it;
// });
// });
}
}
}
}
};
};
...
...
src/include/migraph/shape.hpp
View file @
ffcd5b35
...
@@ -76,7 +76,9 @@ struct shape
...
@@ -76,7 +76,9 @@ struct shape
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
size_t
index
(
std
::
size_t
i
)
const
;
bool
packed
()
const
;
bool
packed
()
const
;
bool
transposed
()
const
;
bool
broadcasted
()
const
;
bool
broadcasted
()
const
;
bool
standard
()
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
...
@@ -139,7 +141,7 @@ struct shape
...
@@ -139,7 +141,7 @@ struct shape
type_t
m_type
;
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_strides
;
std
::
vector
<
std
::
size_t
>
m_strides
;
bool
m_
packe
d
;
bool
m_
standar
d
;
void
calculate_strides
();
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
...
...
src/include/migraph/tensor_view.hpp
View file @
ffcd5b35
...
@@ -88,16 +88,16 @@ struct tensor_view
...
@@ -88,16 +88,16 @@ struct tensor_view
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
}
// TODO: Add iterators so it can handle non
packe
d tensors
// TODO: Add iterators so it can handle non
standar
d tensors
T
*
begin
()
T
*
begin
()
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
return
m_data
;
return
m_data
;
}
}
T
*
end
()
T
*
end
()
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
m_data
;
return
m_data
;
else
else
...
@@ -106,13 +106,13 @@ struct tensor_view
...
@@ -106,13 +106,13 @@ struct tensor_view
const
T
*
begin
()
const
const
T
*
begin
()
const
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
return
m_data
;
return
m_data
;
}
}
const
T
*
end
()
const
const
T
*
end
()
const
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
m_data
;
return
m_data
;
else
else
...
...
src/shape.cpp
View file @
ffcd5b35
...
@@ -8,10 +8,10 @@
...
@@ -8,10 +8,10 @@
namespace
migraph
{
namespace
migraph
{
shape
::
shape
()
:
m_type
(
float_type
),
m_
packe
d
(
false
)
{}
shape
::
shape
()
:
m_type
(
float_type
),
m_
standar
d
(
false
)
{}
shape
::
shape
(
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_
packe
d
(
true
)
{}
shape
::
shape
(
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_
standar
d
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_
packe
d
(
true
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_
standar
d
(
true
)
{
{
this
->
calculate_strides
();
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
...
@@ -22,8 +22,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
...
@@ -22,8 +22,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
"At least one stride must be non-zero"
);
"At least one stride must be non-zero"
);
m_packed
=
this
->
elements
()
==
this
->
element_space
()
and
m_standard
=
this
->
packed
()
and
not
this
->
transposed
();
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
}
}
void
shape
::
calculate_strides
()
void
shape
::
calculate_strides
()
...
@@ -67,7 +66,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
...
@@ -67,7 +66,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
packe
d
())
if
(
this
->
standar
d
())
return
i
;
return
i
;
else
else
return
std
::
inner_product
(
this
->
lens
().
begin
(),
return
std
::
inner_product
(
this
->
lens
().
begin
(),
...
@@ -80,7 +79,9 @@ std::size_t shape::index(std::size_t i) const
...
@@ -80,7 +79,9 @@ std::size_t shape::index(std::size_t i) const
return
((
i
/
stride
)
%
len
)
*
stride
;
return
((
i
/
stride
)
%
len
)
*
stride
;
});
});
}
}
bool
shape
::
packed
()
const
{
return
this
->
m_packed
;
}
bool
shape
::
packed
()
const
{
return
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
transposed
()
const
{
return
not
std
::
is_sorted
(
this
->
strides
().
rbegin
(),
this
->
strides
().
rend
());
}
bool
shape
::
broadcasted
()
const
bool
shape
::
broadcasted
()
const
{
{
...
@@ -91,6 +92,8 @@ bool shape::broadcasted() const
...
@@ -91,6 +92,8 @@ bool shape::broadcasted() const
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
}
}
bool
shape
::
standard
()
const
{
return
this
->
m_standard
;
}
std
::
size_t
shape
::
element_space
()
const
std
::
size_t
shape
::
element_space
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
...
...
test/auto_contiguous_test.cpp
View file @
ffcd5b35
...
@@ -22,11 +22,14 @@ void after_literal_transpose()
...
@@ -22,11 +22,14 @@ void after_literal_transpose()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
l
=
p
.
add_literal
(
get_2x2
());
EXPECT
(
p
.
get_shape
().
packed
());
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
EXPECT
(
not
p
.
get_shape
().
packed
());
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
p
.
compile
(
contigous_target
{});
p
.
compile
(
contigous_target
{});
EXPECT
(
p
.
get_shape
().
packed
());
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
}
}
int
main
()
{
after_literal_transpose
();
}
int
main
()
{
after_literal_transpose
();
}
test/shape_test.cpp
View file @
ffcd5b35
...
@@ -16,19 +16,37 @@ void test_shape_assign()
...
@@ -16,19 +16,37 @@ void test_shape_assign()
void
test_shape_packed_default
()
void
test_shape_packed_default
()
{
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
}};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
}
void
test_shape_packed
()
void
test_shape_packed
()
{
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
}
void
test_shape_transposed
()
void
test_shape_transposed
()
{
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
EXPECT
(
not
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
void
test_shape_broadcasted
()
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
0
}};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
s
.
broadcasted
());
}
}
void
test_shape_default
()
void
test_shape_default
()
...
@@ -42,7 +60,10 @@ void test_shape_default()
...
@@ -42,7 +60,10 @@ void test_shape_default()
void
test_shape4
()
void
test_shape4
()
{
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
...
@@ -86,7 +107,10 @@ void test_shape4_nonpacked()
...
@@ -86,7 +107,10 @@ void test_shape4_nonpacked()
std
::
multiplies
<
std
::
size_t
>
());
std
::
multiplies
<
std
::
size_t
>
());
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
lens
,
strides
};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
lens
,
strides
};
EXPECT
(
!
s
.
packed
());
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
...
@@ -116,6 +140,7 @@ int main()
...
@@ -116,6 +140,7 @@ int main()
test_shape_packed_default
();
test_shape_packed_default
();
test_shape_packed
();
test_shape_packed
();
test_shape_transposed
();
test_shape_transposed
();
test_shape_broadcasted
();
test_shape_default
();
test_shape_default
();
test_shape4
();
test_shape4
();
test_shape4_nonpacked
();
test_shape4_nonpacked
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment