Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
717744ce
Commit
717744ce
authored
Mar 29, 2018
by
Paul
Browse files
Add tensor_view class
parent
2372171d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
229 additions
and
27 deletions
+229
-27
include/rtg/argument.hpp
include/rtg/argument.hpp
+24
-3
include/rtg/literal.hpp
include/rtg/literal.hpp
+22
-12
include/rtg/shape.hpp
include/rtg/shape.hpp
+10
-4
include/rtg/tensor_view.hpp
include/rtg/tensor_view.hpp
+154
-0
src/program.cpp
src/program.cpp
+1
-1
src/shape.cpp
src/shape.cpp
+13
-2
test/eval_test.cpp
test/eval_test.cpp
+5
-5
No files found.
include/rtg/argument.hpp
View file @
717744ce
...
...
@@ -8,16 +8,37 @@ namespace rtg {
struct
argument
{
argument
()
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
shape_
(
s
)
{}
std
::
function
<
char
*
()
>
data
;
shape
s
;
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
template
<
class
Visitor
>
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
{
shape_
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
this
->
data
())
+
shape_
.
index
(
n
)));
});
}
template
<
class
Visitor
>
void
visit
(
Visitor
v
)
const
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
as
.
from
(
data
()));
s
hape_
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
this
->
shape_
,
as
.
from
(
this
->
data
()))
)
;
});
}
private:
shape
shape_
;
};
}
...
...
include/rtg/literal.hpp
View file @
717744ce
...
...
@@ -3,6 +3,7 @@
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <rtg/tensor_view.hpp>
namespace
rtg
{
...
...
@@ -37,12 +38,11 @@ struct literal
bool
result
=
x
.
buffer
.
empty
()
&&
y
.
buffer
.
empty
();
if
(
not
result
&&
x
.
shape_
==
y
.
shape_
and
x
.
buffer
.
size
()
==
y
.
buffer
.
size
())
{
// TODO: Dont use tensor view for single values
x
.
shape_
.
visit_type
([
&
](
auto
as
)
{
auto
space
=
x
.
shape_
.
bytes
()
/
sizeof
(
as
());
auto
*
xstart
=
&
as
.
from
(
x
.
buffer
.
data
());
auto
*
ystart
=
&
as
.
from
(
y
.
buffer
.
data
());
result
=
std
::
equal
(
xstart
,
xstart
+
space
,
ystart
,
ystart
+
space
);
auto
xview
=
make_view
(
x
.
shape_
,
as
.
from
(
x
.
buffer
.
data
()));
auto
yview
=
make_view
(
y
.
shape_
,
as
.
from
(
y
.
buffer
.
data
()));
result
=
xview
==
yview
;
});
}
return
result
;
...
...
@@ -54,10 +54,18 @@ struct literal
}
template
<
class
Visitor
>
void
visit
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
{
shape_
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
this
->
buffer
.
data
())
+
shape_
.
index
(
n
)));
});
}
template
<
class
Visitor
>
void
visit
(
Visitor
v
)
const
{
shape_
.
visit_type
([
&
](
auto
as
)
{
v
(
as
.
from
(
this
->
buffer
.
data
()
,
n
));
v
(
make_view
(
this
->
shape_
,
as
.
from
(
this
->
buffer
.
data
()
)
));
});
}
...
...
@@ -66,11 +74,16 @@ struct literal
return
this
->
buffer
.
empty
();
}
bool
single
()
const
{
return
this
->
shape_
.
elements
()
==
1
;
}
template
<
class
T
>
T
at
(
std
::
size_t
n
=
0
)
const
{
T
result
;
this
->
visit
([
&
](
auto
x
)
{
this
->
visit
_at
([
&
](
auto
x
)
{
result
=
x
;
});
return
result
;
...
...
@@ -83,11 +96,8 @@ struct literal
argument
get_argument
()
const
{
argument
arg
;
auto
b
=
buffer
;
arg
.
s
=
shape_
;
arg
.
data
=
[
b
]()
mutable
{
return
b
.
data
();
};
return
arg
;
return
{
shape_
,
[
b
]()
mutable
{
return
b
.
data
();
}};
}
private:
...
...
include/rtg/shape.hpp
View file @
717744ce
...
...
@@ -46,6 +46,11 @@ struct shape
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
std
::
size_t
index
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
// Map element index to space index
std
::
size_t
index
(
std
::
size_t
i
)
const
;
bool
packed
()
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
...
...
@@ -83,15 +88,15 @@ struct shape
}
template
<
class
U
>
T
&
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
T
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
return
*
(
reinterpret_cast
<
T
*>
(
buffer
)
+
n
)
;
return
reinterpret_cast
<
T
*>
(
buffer
)
+
n
;
}
template
<
class
U
>
const
T
&
from
(
const
U
*
buffer
,
std
::
size_t
n
=
0
)
const
const
T
*
from
(
const
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
return
*
(
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
)
;
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
}
};
...
...
@@ -113,6 +118,7 @@ private:
type_t
type_
;
std
::
vector
<
std
::
size_t
>
lens_
;
std
::
vector
<
std
::
size_t
>
strides_
;
bool
packed_
;
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
...
...
include/rtg/tensor_view.hpp
0 → 100644
View file @
717744ce
#ifndef RTG_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP
#include <rtg/shape.hpp>
#include <iostream>
namespace
rtg
{
template
<
class
T
>
struct
tensor_view
{
tensor_view
()
:
data_
(
nullptr
),
shape_
()
{}
tensor_view
(
shape
s
,
T
*
d
)
:
data_
(
d
),
shape_
(
s
)
{}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
bool
empty
()
const
{
return
data_
==
nullptr
||
shape_
.
lens
().
size
()
==
0
;
}
std
::
size_t
size
()
const
{
return
shape_
.
elements
();
}
T
*
data
()
{
return
this
->
data_
;
}
const
T
*
data
()
const
{
return
this
->
data_
;
}
template
<
class
...
Ts
>
const
T
&
operator
()(
Ts
...
xs
)
const
{
return
data_
[
shape_
.
index
({
xs
...})];
}
template
<
class
...
Ts
>
T
&
operator
()(
Ts
...
xs
)
{
return
data_
[
shape_
.
index
({
xs
...})];
}
T
&
operator
[](
std
::
size_t
i
)
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data_
[
shape_
.
index
(
i
)];
}
const
T
&
operator
[](
std
::
size_t
i
)
const
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data_
[
shape_
.
index
(
i
)];
}
T
&
front
()
{
assert
(
!
this
->
empty
());
return
data_
[
0
];
}
const
T
&
front
()
const
{
assert
(
!
this
->
empty
());
return
data_
[
0
];
}
T
&
back
()
{
assert
(
!
this
->
empty
());
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
}
const
T
&
back
()
const
{
assert
(
!
this
->
empty
());
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
}
// TODO: Add iterators so it can handle nonpacked tensors
T
*
begin
()
{
assert
(
this
->
shape_
.
packed
());
return
data_
;
}
T
*
end
()
{
assert
(
this
->
shape_
.
packed
());
if
(
this
->
empty
())
return
data_
;
else
return
data_
+
this
->
size
();
}
const
T
*
begin
()
const
{
assert
(
this
->
shape_
.
packed
());
return
data_
;
}
const
T
*
end
()
const
{
assert
(
this
->
shape_
.
packed
());
if
(
this
->
empty
())
return
data_
;
else
return
data_
+
this
->
size
();
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
if
(
x
.
shape_
==
y
.
shape_
)
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape_
.
elements
();
i
++
)
{
std
::
cout
<<
x
[
i
]
<<
" == "
<<
y
[
i
]
<<
std
::
endl
;
if
(
x
[
i
]
==
y
[
i
])
std
::
cout
<<
"true"
<<
std
::
endl
;;
if
(
x
[
i
]
!=
y
[
i
])
std
::
cout
<<
"true"
<<
std
::
endl
;;
if
(
x
[
i
]
!=
y
[
i
])
return
false
;
}
return
true
;
}
return
false
;
}
friend
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
return
!
(
x
==
y
);
}
private:
T
*
data_
;
shape
shape_
;
};
template
<
class
T
>
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
{
return
{
s
,
data
};
}
}
// namespace rtg
#endif
src/program.cpp
View file @
717744ce
...
...
@@ -24,7 +24,7 @@ literal program::eval() const
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
return
literal
{
result
.
s
,
result
.
data
()};
return
literal
{
result
.
get_shape
()
,
result
.
data
()};
}
}
...
...
src/shape.cpp
View file @
717744ce
...
...
@@ -11,10 +11,10 @@ shape::shape()
{}
shape
::
shape
(
type_t
t
)
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
})
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
})
,
packed_
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type_
(
t
),
lens_
(
std
::
move
(
l
))
:
type_
(
t
),
lens_
(
std
::
move
(
l
))
,
packed_
(
true
)
{
this
->
calculate_strides
();
assert
(
lens_
.
size
()
==
strides_
.
size
());
...
...
@@ -23,6 +23,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
strides_
(
std
::
move
(
s
))
{
assert
(
lens_
.
size
()
==
strides_
.
size
());
packed_
=
this
->
elements
()
==
this
->
element_space
();
}
void
shape
::
calculate_strides
()
...
...
@@ -72,6 +73,16 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
}
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
}
bool
shape
::
packed
()
const
{
return
this
->
packed_
;
}
std
::
size_t
shape
::
element_space
()
const
{
// TODO: Get rid of intermediate vector
...
...
test/eval_test.cpp
View file @
717744ce
...
...
@@ -11,12 +11,12 @@ int main() {
[](
std
::
vector
<
rtg
::
argument
>
args
)
{
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
throw
"Wrong args"
;
if
(
args
[
0
].
s
!=
args
[
1
].
s
)
throw
"Wrong args"
;
if
(
args
[
0
].
s
.
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
s
.
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
()
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
.
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
.
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
args
[
0
].
visit
([
&
](
auto
x
)
{
args
[
1
].
visit
([
&
](
auto
y
)
{
args
[
0
].
visit
_at
([
&
](
auto
x
)
{
args
[
1
].
visit
_at
([
&
](
auto
y
)
{
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment