Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f0c7f958
Commit
f0c7f958
authored
Mar 27, 2018
by
Paul
Browse files
Setup literal class
parent
303a1b53
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
173 additions
and
12 deletions
+173
-12
include/rtg/literal.hpp
include/rtg/literal.hpp
+70
-0
include/rtg/shape.hpp
include/rtg/shape.hpp
+48
-10
src/shape.cpp
src/shape.cpp
+20
-2
test/literal_test.cpp
test/literal_test.cpp
+24
-0
test/shape_test.cpp
test/shape_test.cpp
+11
-0
No files found.
include/rtg/literal.hpp
View file @
f0c7f958
...
@@ -7,6 +7,76 @@ namespace rtg {
...
@@ -7,6 +7,76 @@ namespace rtg {
struct
literal
struct
literal
{
{
literal
()
:
buffer
(),
shape_
()
{}
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
shape_
(
shape
::
get_type
<
T
>
{})
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
}
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
std
::
copy
(
x
.
begin
(),
x
.
end
(),
reinterpret_cast
<
T
*>
(
buffer
.
data
()));
}
friend
bool
operator
==
(
const
literal
&
x
,
const
literal
&
y
)
{
bool
result
=
x
.
buffer
.
empty
()
&&
y
.
buffer
.
empty
();
if
(
not
result
&&
x
.
shape_
==
y
.
shape_
and
x
.
buffer
.
size
()
==
y
.
buffer
.
size
())
{
x
.
shape_
.
visit_type
([
&
](
auto
as
)
{
auto
space
=
x
.
shape_
.
bytes
()
/
sizeof
(
as
());
auto
*
xstart
=
&
as
.
from
(
x
.
buffer
.
data
());
auto
*
ystart
=
&
as
.
from
(
y
.
buffer
.
data
());
result
=
std
::
equal
(
xstart
,
xstart
+
space
,
ystart
,
ystart
+
space
);
});
}
return
result
;
}
friend
bool
operator
!=
(
const
literal
&
x
,
const
literal
&
y
)
{
return
!
(
x
==
y
);
}
template
<
class
Visitor
>
void
visit
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
{
shape_
.
visit_type
([
&
](
auto
as
)
{
v
(
as
.
from
(
this
->
buffer
.
data
(),
n
));
});
}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
template
<
class
T
>
T
at
(
std
::
size_t
n
=
0
)
const
{
T
result
;
this
->
visit
([
&
](
auto
x
)
{
result
=
x
;
});
return
result
;
}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
private:
std
::
vector
<
char
>
buffer
;
std
::
vector
<
char
>
buffer
;
shape
shape_
;
shape
shape_
;
};
};
...
...
include/rtg/shape.hpp
View file @
f0c7f958
...
@@ -4,15 +4,32 @@
...
@@ -4,15 +4,32 @@
#include <vector>
#include <vector>
#include <cassert>
#include <cassert>
namespace
rtg
{
namespace
rtg
{
struct
shape
struct
shape
{
{
// Add new types here
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(int_type, int) \
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum
type_t
enum
type_t
{
{
float_type
,
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_ENUM_TYPES
)
int_type
};
};
#undef RTG_SHAPE_ENUM_TYPES
template
<
class
T
,
class
=
void
>
struct
get_type
;
#define RTG_SHAPE_GET_TYPE(x, t) \
template<class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{};
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_GET_TYPE
)
#undef RTG_SHAPE_GET_TYPE
shape
();
shape
();
shape
(
type_t
t
);
shape
(
type_t
t
);
...
@@ -21,11 +38,14 @@ struct shape
...
@@ -21,11 +38,14 @@ struct shape
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>
&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>
&
strides
()
const
;
std
::
size_t
elements
()
const
;
std
::
size_t
elements
()
const
;
std
::
size_t
bytes
()
const
;
std
::
size_t
bytes
()
const
;
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
std
::
size_t
index
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
...
@@ -40,12 +60,24 @@ struct shape
...
@@ -40,12 +60,24 @@ struct shape
return
T
(
u
);
return
T
(
u
);
}
}
template
<
class
U
>
T
*
operator
()(
U
*
u
)
const
{
return
static_cast
<
T
*>
(
u
);
}
template
<
class
U
>
const
T
*
operator
()(
const
U
*
u
)
const
{
return
static_cast
<
T
*>
(
u
);
}
T
operator
()()
const
T
operator
()()
const
{
{
return
{};
return
{};
}
}
std
::
size_t
size
(
std
::
size_t
n
=
0
)
const
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
{
return
sizeof
(
T
)
*
n
;
return
sizeof
(
T
)
*
n
;
}
}
...
@@ -55,6 +87,12 @@ struct shape
...
@@ -55,6 +87,12 @@ struct shape
{
{
return
*
(
reinterpret_cast
<
T
*>
(
buffer
)
+
n
);
return
*
(
reinterpret_cast
<
T
*>
(
buffer
)
+
n
);
}
}
template
<
class
U
>
const
T
&
from
(
const
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
return
*
(
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
);
}
};
};
template
<
class
Visitor
>
template
<
class
Visitor
>
...
@@ -62,12 +100,12 @@ struct shape
...
@@ -62,12 +100,12 @@ struct shape
{
{
switch
(
this
->
type_
)
switch
(
this
->
type_
)
{
{
case
float_type
:
#define RTG_SHAPE_VISITOR_CASE(x, t) \
v
(
as
<
float
>
());
case x: \
return
;
v(as<t>()); \
case
int_type
:
v
(
as
<
int
>
());
return;
return;
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_VISITOR_CASE
)
#undef RTG_SHAPE_VISITOR_CASE
}
}
assert
(
true
);
assert
(
true
);
}
}
...
...
src/shape.cpp
View file @
f0c7f958
...
@@ -6,6 +6,10 @@
...
@@ -6,6 +6,10 @@
namespace
rtg
{
namespace
rtg
{
shape
::
shape
()
:
type_
(
float_type
),
lens_
(),
strides_
()
{}
shape
::
shape
(
type_t
t
)
shape
::
shape
(
type_t
t
)
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
})
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
})
{}
{}
...
@@ -36,16 +40,17 @@ shape::type_t shape::type() const
...
@@ -36,16 +40,17 @@ shape::type_t shape::type() const
{
{
return
this
->
type_
;
return
this
->
type_
;
}
}
const
std
::
vector
<
std
::
size_t
>
shape
::
lens
()
const
const
std
::
vector
<
std
::
size_t
>
&
shape
::
lens
()
const
{
{
return
this
->
lens_
;
return
this
->
lens_
;
}
}
const
std
::
vector
<
std
::
size_t
>
shape
::
strides
()
const
const
std
::
vector
<
std
::
size_t
>
&
shape
::
strides
()
const
{
{
return
this
->
strides_
;
return
this
->
strides_
;
}
}
std
::
size_t
shape
::
elements
()
const
std
::
size_t
shape
::
elements
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
accumulate
(
return
std
::
accumulate
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
}
...
@@ -55,9 +60,22 @@ std::size_t shape::bytes() const
...
@@ -55,9 +60,22 @@ std::size_t shape::bytes() const
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
*
this
->
element_space
();
return
n
*
this
->
element_space
();
}
}
std
::
size_t
shape
::
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
{
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
}
std
::
size_t
shape
::
index
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
{
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
}
std
::
size_t
shape
::
element_space
()
const
std
::
size_t
shape
::
element_space
()
const
{
{
// TODO: Get rid of intermediate vector
// TODO: Get rid of intermediate vector
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
std
::
vector
<
std
::
size_t
>
max_indices
(
this
->
lens
().
size
());
std
::
vector
<
std
::
size_t
>
max_indices
(
this
->
lens
().
size
());
std
::
transform
(
this
->
lens
().
begin
(),
std
::
transform
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
lens
().
end
(),
...
...
test/literal_test.cpp
0 → 100644
View file @
f0c7f958
#include <rtg/literal.hpp>
#include "test.hpp"
int
main
()
{
EXPECT
(
rtg
::
literal
{
1
}
==
rtg
::
literal
{
1
});
EXPECT
(
rtg
::
literal
{
1
}
!=
rtg
::
literal
{
2
});
EXPECT
(
rtg
::
literal
{}
==
rtg
::
literal
{});
EXPECT
(
rtg
::
literal
{}
!=
rtg
::
literal
{
2
});
rtg
::
literal
l1
{
1
};
rtg
::
literal
l2
=
l1
;
EXPECT
(
l1
==
l2
);
EXPECT
(
l1
.
at
<
int
>
(
0
)
==
1
);
EXPECT
(
!
l1
.
empty
());
EXPECT
(
!
l2
.
empty
());
rtg
::
literal
l3
{};
rtg
::
literal
l4
{};
EXPECT
(
l3
==
l4
);
EXPECT
(
l3
.
empty
());
EXPECT
(
l4
.
empty
());
}
test/shape_test.cpp
View file @
f0c7f958
...
@@ -10,6 +10,14 @@ void test_shape_assign()
...
@@ -10,6 +10,14 @@ void test_shape_assign()
EXPECT
(
!
(
s1
!=
s2
));
EXPECT
(
!
(
s1
!=
s2
));
}
}
void
test_shape_default
()
{
rtg
::
shape
s1
{};
rtg
::
shape
s2
{};
EXPECT
(
s1
==
s2
);
EXPECT
(
!
(
s1
!=
s2
));
}
void
test_shape4
()
void
test_shape4
()
{
{
rtg
::
shape
s
{
rtg
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
rtg
::
shape
s
{
rtg
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
...
@@ -22,10 +30,13 @@ void test_shape4()
...
@@ -22,10 +30,13 @@ void test_shape4()
EXPECT
(
s
.
strides
()[
1
]
==
s
.
lens
()[
2
]
*
s
.
strides
()[
2
]);
EXPECT
(
s
.
strides
()[
1
]
==
s
.
lens
()[
2
]
*
s
.
strides
()[
2
]);
EXPECT
(
s
.
strides
()[
2
]
==
s
.
lens
()[
3
]
*
s
.
strides
()[
3
]);
EXPECT
(
s
.
strides
()[
2
]
==
s
.
lens
()[
3
]
*
s
.
strides
()[
3
]);
EXPECT
(
s
.
strides
()[
3
]
==
1
);
EXPECT
(
s
.
strides
()[
3
]
==
1
);
EXPECT
(
s
.
elements
()
==
100
*
32
*
8
*
8
);
EXPECT
(
s
.
bytes
()
==
100
*
32
*
8
*
8
*
sizeof
(
float
));
}
}
int
main
()
{
int
main
()
{
test_shape_assign
();
test_shape_assign
();
test_shape_default
();
test_shape4
();
test_shape4
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment