Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
83aa9ac3
Commit
83aa9ac3
authored
Apr 23, 2018
by
Paul
Browse files
Enforce correct naming
parent
81778d3d
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
173 additions
and
163 deletions
+173
-163
.clang-tidy
.clang-tidy
+51
-49
include/rtg/argument.hpp
include/rtg/argument.hpp
+3
-3
include/rtg/literal.hpp
include/rtg/literal.hpp
+8
-8
include/rtg/operand.hpp
include/rtg/operand.hpp
+56
-48
include/rtg/shape.hpp
include/rtg/shape.hpp
+5
-5
include/rtg/tensor_view.hpp
include/rtg/tensor_view.hpp
+31
-31
src/shape.cpp
src/shape.cpp
+17
-17
test/test.hpp
test/test.hpp
+2
-2
No files found.
.clang-tidy
View file @
83aa9ac3
...
...
@@ -11,91 +11,93 @@ CheckOptions:
value: '10'
- key: readability-function-size.StatementThreshold
value: '150'
- key: readability-identifier-naming.Namespace
- key: readability-identifier-naming.Namespace
Case
value: lower_case
- key: readability-identifier-naming.InlineNamespace
- key: readability-identifier-naming.InlineNamespace
Case
value: lower_case
- key: readability-identifier-naming.EnumConstant
- key: readability-identifier-naming.EnumConstant
Case
value: lower_case
- key: readability-identifier-naming.ConstexprVariable
- key: readability-identifier-naming.ConstexprVariable
Case
value: lower_case
- key: readability-identifier-naming.ConstantMember
- key: readability-identifier-naming.ConstantMember
Case
value: lower_case
- key: readability-identifier-naming.PrivateMember
- key: readability-identifier-naming.PrivateMember
Case
value: lower_case
- key: readability-identifier-naming.ProtectedMember
- key: readability-identifier-naming.ProtectedMember
Case
value: lower_case
- key: readability-identifier-naming.PublicMember
- key: readability-identifier-naming.PublicMember
Case
value: lower_case
- key: readability-identifier-naming.Member
- key: readability-identifier-naming.Member
Case
value: lower_case
- key: readability-identifier-naming.ClassConstant
- key: readability-identifier-naming.ClassConstant
Case
value: lower_case
- key: readability-identifier-naming.ClassMember
- key: readability-identifier-naming.ClassMember
Case
value: lower_case
- key: readability-identifier-naming.GlobalConstant
- key: readability-identifier-naming.GlobalConstant
Case
value: lower_case
- key: readability-identifier-naming.GlobalVariable
- key: readability-identifier-naming.GlobalVariable
Case
value: lower_case
- key: readability-identifier-naming.LocalConstant
- key: readability-identifier-naming.LocalConstant
Case
value: lower_case
- key: readability-identifier-naming.LocalVariable
- key: readability-identifier-naming.LocalVariable
Case
value: lower_case
- key: readability-identifier-naming.StaticConstant
- key: readability-identifier-naming.StaticConstant
Case
value: lower_case
- key: readability-identifier-naming.StaticVariable
- key: readability-identifier-naming.StaticVariable
Case
value: lower_case
- key: readability-identifier-naming.Constant
- key: readability-identifier-naming.Constant
Case
value: lower_case
- key: readability-identifier-naming.Variable
- key: readability-identifier-naming.Variable
Case
value: lower_case
- key: readability-identifier-naming.ConstantParameter
- key: readability-identifier-naming.ConstantParameter
Case
value: lower_case
- key: readability-identifier-naming.ParameterPack
- key: readability-identifier-naming.ParameterPack
Case
value: lower_case
- key: readability-identifier-naming.Parameter
- key: readability-identifier-naming.Parameter
Case
value: lower_case
- key: readability-identifier-naming.AbstractClass
- key: readability-identifier-naming.AbstractClass
Case
value: lower_case
- key: readability-identifier-naming.Struct
- key: readability-identifier-naming.Struct
Case
value: lower_case
- key: readability-identifier-naming.Class
- key: readability-identifier-naming.Class
Case
value: lower_case
- key: readability-identifier-naming.Union
- key: readability-identifier-naming.Union
Case
value: lower_case
- key: readability-identifier-naming.Enum
- key: readability-identifier-naming.Enum
Case
value: lower_case
- key: readability-identifier-naming.GlobalFunction
- key: readability-identifier-naming.GlobalFunction
Case
value: lower_case
- key: readability-identifier-naming.ConstexprFunction
- key: readability-identifier-naming.ConstexprFunction
Case
value: lower_case
- key: readability-identifier-naming.Function
- key: readability-identifier-naming.Function
Case
value: lower_case
- key: readability-identifier-naming.ConstexprMethod
- key: readability-identifier-naming.ConstexprMethod
Case
value: lower_case
- key: readability-identifier-naming.VirtualMethod
- key: readability-identifier-naming.VirtualMethod
Case
value: lower_case
- key: readability-identifier-naming.ClassMethod
- key: readability-identifier-naming.ClassMethod
Case
value: lower_case
- key: readability-identifier-naming.PrivateMethod
- key: readability-identifier-naming.PrivateMethod
Case
value: lower_case
- key: readability-identifier-naming.ProtectedMethod
- key: readability-identifier-naming.ProtectedMethod
Case
value: lower_case
- key: readability-identifier-naming.PublicMethod
- key: readability-identifier-naming.PublicMethod
Case
value: lower_case
- key: readability-identifier-naming.Method
- key: readability-identifier-naming.Method
Case
value: lower_case
- key: readability-identifier-naming.Typedef
- key: readability-identifier-naming.Typedef
Case
value: lower_case
- key: readability-identifier-naming.TypeTemplateParameter
value: lower_case
- key: readability-identifier-naming.ValueTemplateParameter
value: lower_case
- key: readability-identifier-naming.TemplateTemplateParameter
value: lower_case
- key: readability-identifier-naming.TemplateParameter
value: lower_case
- key: readability-identifier-naming.TypeAlias
value: lower_case
- key: readability-identifier-naming.MacroDefinition
- key: readability-identifier-naming.TypeTemplateParameterCase
value: CamelCase
- key: readability-identifier-naming.ValueTemplateParameterCase
value: CamelCase
- key: readability-identifier-naming.TemplateTemplateParameterCase
value: CamelCase
- key: readability-identifier-naming.TemplateParameterCase
value: CamelCase
- key: readability-identifier-naming.TypeAliasCase
value: lower_case
# - key: readability-identifier-naming.MacroDefinitionCase
# value: UPPER_CASE
# - key: readability-identifier-naming.MacroDefinitionPrefix
# value: RTG_
include/rtg/argument.hpp
View file @
83aa9ac3
...
...
@@ -11,16 +11,16 @@ struct argument : raw_data<argument>
{
argument
()
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
shape
_
(
s
)
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
m_
shape
(
s
)
{}
std
::
function
<
char
*
()
>
data
;
bool
empty
()
const
{
return
not
data
;
}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
private:
shape
shape
_
;
shape
m_
shape
;
};
}
// namespace rtg
...
...
include/rtg/literal.hpp
View file @
83aa9ac3
...
...
@@ -13,14 +13,14 @@ struct literal : raw_data<literal>
literal
()
{}
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
shape
_
(
shape
::
get_type
<
T
>
{})
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
m_
shape
(
shape
::
get_type
<
T
>
{})
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
}
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
...
...
@@ -28,7 +28,7 @@ struct literal : raw_data<literal>
}
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
...
...
@@ -36,29 +36,29 @@ struct literal : raw_data<literal>
}
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
assert
(
s
.
packed
());
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
shape
_
(
s
)
{}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
m_
shape
(
s
)
{}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
argument
get_argument
()
const
{
auto
b
=
buffer
;
return
{
shape
_
,
[
b
]()
mutable
{
return
b
.
data
();
}};
return
{
m_
shape
,
[
b
]()
mutable
{
return
b
.
data
();
}};
}
private:
std
::
vector
<
char
>
buffer
;
shape
shape
_
;
shape
m_
shape
;
};
}
// namespace rtg
...
...
include/rtg/operand.hpp
View file @
83aa9ac3
...
...
@@ -28,112 +28,120 @@ struct operand
// Constructors
operand
()
=
default
;
template
<
typename
TypeErased_T_
>
operand
(
TypeErased_T_
value
)
:
handle_mem_var_
(
std
::
make_shared
<
handle_type_
<
typename
std
::
remove_reference
<
TypeErased_T_
>::
type
>>
(
std
::
forward
<
TypeErased_T_
>
(
value
)))
template
<
typename
PrivateDetailTypeErasedT
>
operand
(
PrivateDetailTypeErasedT
value
)
:
private_detail_te_handle_mem_var
(
std
::
make_shared
<
private_detail_te_handle_type
<
typename
std
::
remove_reference
<
PrivateDetailTypeErasedT
>::
type
>>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
)))
{
}
// Assignment
template
<
typename
TypeErased
_T_
>
operand
&
operator
=
(
TypeErased
_T_
value
)
template
<
typename
PrivateDetail
TypeErased
T
>
operand
&
operator
=
(
PrivateDetail
TypeErased
T
value
)
{
if
(
handle_mem_var_
.
unique
())
*
handle_mem_var_
=
std
::
forward
<
TypeErased_T_
>
(
value
);
else
if
(
!
handle_mem_var_
)
handle_mem_var_
=
std
::
make_shared
<
TypeErased_T_
>
(
std
::
forward
<
TypeErased_T_
>
(
value
));
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
return
*
this
;
}
std
::
string
name
()
const
{
assert
(
handle_mem_var
_
);
return
get_handle
_
().
name
();
assert
(
private_detail_te_
handle_mem_var
);
return
private_detail_te_
get_handle
().
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
assert
(
handle_mem_var
_
);
return
get_handle
_
().
compute_shape
(
std
::
move
(
input
));
assert
(
private_detail_te_
handle_mem_var
);
return
private_detail_te_
get_handle
().
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
assert
(
handle_mem_var
_
);
return
get_handle
_
().
compute
(
std
::
move
(
input
));
assert
(
private_detail_te_
handle_mem_var
);
return
private_detail_te_
get_handle
().
compute
(
std
::
move
(
input
));
}
private:
struct
handle_base_type
_
struct
private_detail_te_
handle_base_type
{
virtual
~
handle_base_type
_
()
{}
virtual
std
::
shared_ptr
<
handle_base_type
_
>
clone
()
const
=
0
;
virtual
~
private_detail_te_
handle_base_type
()
{}
virtual
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
};
template
<
typename
TypeErased
_T_
>
struct
handle_type
_
:
handle_base_type
_
template
<
typename
PrivateDetail
TypeErased
T
>
struct
private_detail_te_
handle_type
:
private_detail_te_
handle_base_type
{
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
TypeErased_U_
>::
value
>::
type
*
=
nullptr
)
:
value_
(
value
)
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
>::
type
*
=
nullptr
)
:
private_detail_te_value
(
value
)
{
}
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
TypeErased_U_
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
value_
(
std
::
move
(
value
))
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
}
std
::
shared_ptr
<
handle_base_type
_
>
clone
()
const
override
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
clone
()
const
override
{
return
std
::
make_shared
<
handle_type
_
>
(
value
_
);
return
std
::
make_shared
<
private_detail_te_
handle_type
>
(
private_detail_te_
value
);
}
std
::
string
name
()
const
override
{
return
value
_
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_
value
.
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
{
return
value
_
.
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_
value
.
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
{
return
value
_
.
compute
(
std
::
move
(
input
));
return
private_detail_te_
value
.
compute
(
std
::
move
(
input
));
}
TypeErased_T_
value_
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
template
<
typename
TypeErased_T_
>
struct
handle_type_
<
std
::
reference_wrapper
<
TypeErased_T_
>>
:
handle_type_
<
TypeErased_T_
&>
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
<
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>>
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
{
handle_type
_
(
std
::
reference_wrapper
<
TypeErased
_T_
>
ref
)
:
handle_type_
<
TypeErased
_T_
&>
(
ref
.
get
())
private_detail_te_
handle_type
(
std
::
reference_wrapper
<
PrivateDetail
TypeErased
T
>
ref
)
:
private_detail_te_handle_type
<
PrivateDetail
TypeErased
T
&>
(
ref
.
get
())
{
}
};
const
handle_base_type_
&
get_handle_
()
const
{
return
*
handle_mem_var_
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
return
*
private_detail_te_handle_mem_var
;
}
handle_base_type
_
&
get_handle
_
()
private_detail_te_
handle_base_type
&
private_detail_te_
get_handle
()
{
if
(
!
handle_mem_var
_
.
unique
())
handle_mem_var
_
=
handle_mem_var
_
->
clone
();
return
*
handle_mem_var
_
;
if
(
!
private_detail_te_
handle_mem_var
.
unique
())
private_detail_te_
handle_mem_var
=
private_detail_te_
handle_mem_var
->
clone
();
return
*
private_detail_te_
handle_mem_var
;
}
std
::
shared_ptr
<
handle_base_type
_
>
handle_mem_var
_
;
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
private_detail_te_
handle_mem_var
;
};
}
// namespace rtg
...
...
include/rtg/shape.hpp
View file @
83aa9ac3
...
...
@@ -110,7 +110,7 @@ struct shape
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
{
switch
(
this
->
type
_
)
switch
(
this
->
m_
type
)
{
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
...
...
@@ -121,10 +121,10 @@ struct shape
}
private:
type_t
type
_
;
std
::
vector
<
std
::
size_t
>
lens
_
;
std
::
vector
<
std
::
size_t
>
strides
_
;
bool
packed
_
;
type_t
m_
type
;
std
::
vector
<
std
::
size_t
>
m_
lens
;
std
::
vector
<
std
::
size_t
>
m_
strides
;
bool
m_
packed
;
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
...
...
include/rtg/tensor_view.hpp
View file @
83aa9ac3
...
...
@@ -11,103 +11,103 @@ namespace rtg {
template
<
class
T
>
struct
tensor_view
{
tensor_view
()
:
data
_
(
nullptr
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
data
_
(
d
),
shape
_
(
s
)
{}
tensor_view
()
:
m_
data
(
nullptr
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
m_
data
(
d
),
m_
shape
(
s
)
{}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
bool
empty
()
const
{
return
data
_
==
nullptr
||
shape
_
.
lens
().
empty
();
}
bool
empty
()
const
{
return
m_
data
==
nullptr
||
m_
shape
.
lens
().
empty
();
}
std
::
size_t
size
()
const
{
return
shape
_
.
elements
();
}
std
::
size_t
size
()
const
{
return
m_
shape
.
elements
();
}
T
*
data
()
{
return
this
->
data
_
;
}
T
*
data
()
{
return
this
->
m_
data
;
}
const
T
*
data
()
const
{
return
this
->
data
_
;
}
const
T
*
data
()
const
{
return
this
->
m_
data
;
}
template
<
class
...
Ts
>
const
T
&
operator
()(
Ts
...
xs
)
const
{
return
data
_
[
shape
_
.
index
({
xs
...})];
return
m_
data
[
m_
shape
.
index
({
xs
...})];
}
template
<
class
...
Ts
>
T
&
operator
()(
Ts
...
xs
)
{
return
data
_
[
shape
_
.
index
({
xs
...})];
return
m_
data
[
m_
shape
.
index
({
xs
...})];
}
T
&
operator
[](
std
::
size_t
i
)
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data
_
[
shape
_
.
index
(
i
)];
return
m_
data
[
m_
shape
.
index
(
i
)];
}
const
T
&
operator
[](
std
::
size_t
i
)
const
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data
_
[
shape
_
.
index
(
i
)];
return
m_
data
[
m_
shape
.
index
(
i
)];
}
T
&
front
()
{
assert
(
!
this
->
empty
());
return
data
_
[
0
];
return
m_
data
[
0
];
}
const
T
&
front
()
const
{
assert
(
!
this
->
empty
());
return
data
_
[
0
];
return
m_
data
[
0
];
}
T
&
back
()
{
assert
(
!
this
->
empty
());
return
data
_
[
shape
_
.
index
(
this
->
size
()
-
1
)];
return
m_
data
[
m_
shape
.
index
(
this
->
size
()
-
1
)];
}
const
T
&
back
()
const
{
assert
(
!
this
->
empty
());
return
data
_
[
shape
_
.
index
(
this
->
size
()
-
1
)];
return
m_
data
[
m_
shape
.
index
(
this
->
size
()
-
1
)];
}
// TODO: Add iterators so it can handle nonpacked tensors
T
*
begin
()
{
assert
(
this
->
shape
_
.
packed
());
return
data
_
;
assert
(
this
->
m_
shape
.
packed
());
return
m_
data
;
}
T
*
end
()
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
if
(
this
->
empty
())
return
data
_
;
return
m_
data
;
else
return
data
_
+
this
->
size
();
return
m_
data
+
this
->
size
();
}
const
T
*
begin
()
const
{
assert
(
this
->
shape
_
.
packed
());
return
data
_
;
assert
(
this
->
m_
shape
.
packed
());
return
m_
data
;
}
const
T
*
end
()
const
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
if
(
this
->
empty
())
return
data
_
;
return
m_
data
;
else
return
data
_
+
this
->
size
();
return
m_
data
+
this
->
size
();
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
if
(
x
.
shape
_
==
y
.
shape
_
)
if
(
x
.
m_
shape
==
y
.
m_
shape
)
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape
_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
.
m_
shape
.
elements
();
i
++
)
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
...
...
@@ -124,17 +124,17 @@ struct tensor_view
if
(
!
x
.
empty
())
{
os
<<
x
.
front
();
for
(
std
::
size_t
i
=
1
;
i
<
x
.
shape
_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_
shape
.
elements
();
i
++
)
{
os
<<
", "
<<
x
.
data
_
[
x
.
shape
_
.
index
(
i
)];
os
<<
", "
<<
x
.
m_
data
[
x
.
m_
shape
.
index
(
i
)];
}
}
return
os
;
}
private:
T
*
data
_
;
shape
shape
_
;
T
*
m_
data
;
shape
m_
shape
;
};
template
<
class
T
>
...
...
src/shape.cpp
View file @
83aa9ac3
...
...
@@ -7,35 +7,35 @@
namespace
rtg
{
shape
::
shape
()
:
type
_
(
float_type
),
packed
_
(
false
)
{}
shape
::
shape
()
:
m_
type
(
float_type
),
m_
packed
(
false
)
{}
shape
::
shape
(
type_t
t
)
:
type
_
(
t
),
lens
_
({
1
}),
strides
_
({
1
}),
packed
_
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type
_
(
t
),
lens
_
(
std
::
move
(
l
)),
packed
_
(
true
)
shape
::
shape
(
type_t
t
)
:
m_
type
(
t
),
m_
lens
({
1
}),
m_
strides
({
1
}),
m_
packed
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_
type
(
t
),
m_
lens
(
std
::
move
(
l
)),
m_
packed
(
true
)
{
this
->
calculate_strides
();
assert
(
lens
_
.
size
()
==
strides
_
.
size
());
assert
(
m_
lens
.
size
()
==
m_
strides
.
size
());
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
type
_
(
t
),
lens
_
(
std
::
move
(
l
)),
strides
_
(
std
::
move
(
s
))
:
m_
type
(
t
),
m_
lens
(
std
::
move
(
l
)),
m_
strides
(
std
::
move
(
s
))
{
assert
(
lens
_
.
size
()
==
strides
_
.
size
());
packed
_
=
this
->
elements
()
==
this
->
element_space
();
assert
(
m_
lens
.
size
()
==
m_
strides
.
size
());
m_
packed
=
this
->
elements
()
==
this
->
element_space
();
}
void
shape
::
calculate_strides
()
{
strides
_
.
clear
();
strides
_
.
resize
(
lens
_
.
size
(),
0
);
if
(
strides
_
.
empty
())
m_
strides
.
clear
();
m_
strides
.
resize
(
m_
lens
.
size
(),
0
);
if
(
m_
strides
.
empty
())
return
;
strides
_
.
back
()
=
1
;
m_
strides
.
back
()
=
1
;
std
::
partial_sum
(
lens
_
.
rbegin
(),
lens
_
.
rend
()
-
1
,
strides
_
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
m_
lens
.
rbegin
(),
m_
lens
.
rend
()
-
1
,
m_
strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
type
_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
lens
_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
strides
_
;
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
m_
type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
m_
lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
m_
strides
;
}
std
::
size_t
shape
::
elements
()
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
...
...
@@ -71,7 +71,7 @@ std::size_t shape::index(std::size_t i) const
std
::
plus
<
std
::
size_t
>
{},
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
}
bool
shape
::
packed
()
const
{
return
this
->
packed
_
;
}
bool
shape
::
packed
()
const
{
return
this
->
m_
packed
;
}
std
::
size_t
shape
::
element_space
()
const
{
// TODO: Get rid of intermediate vector
...
...
@@ -89,7 +89,7 @@ std::size_t shape::element_space() const
std
::
string
shape
::
type_string
()
const
{
switch
(
this
->
type
_
)
switch
(
this
->
m_
type
)
{
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
...
...
test/test.hpp
View file @
83aa9ac3
...
...
@@ -4,8 +4,8 @@
#include <cstdlib>
#include <iostream>
#ifndef GUARD_TEST_TEST_HPP
_
#define GUARD_TEST_TEST_HPP
_
#ifndef GUARD_TEST_TEST_HPP
#define GUARD_TEST_TEST_HPP
inline
void
failed
(
const
char
*
msg
,
const
char
*
file
,
int
line
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment