Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
83aa9ac3
Commit
83aa9ac3
authored
Apr 23, 2018
by
Paul
Browse files
Enforce correct naming
parent
81778d3d
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
173 additions
and
163 deletions
+173
-163
.clang-tidy
.clang-tidy
+51
-49
include/rtg/argument.hpp
include/rtg/argument.hpp
+3
-3
include/rtg/literal.hpp
include/rtg/literal.hpp
+8
-8
include/rtg/operand.hpp
include/rtg/operand.hpp
+56
-48
include/rtg/shape.hpp
include/rtg/shape.hpp
+5
-5
include/rtg/tensor_view.hpp
include/rtg/tensor_view.hpp
+31
-31
src/shape.cpp
src/shape.cpp
+17
-17
test/test.hpp
test/test.hpp
+2
-2
No files found.
.clang-tidy
View file @
83aa9ac3
...
@@ -11,91 +11,93 @@ CheckOptions:
...
@@ -11,91 +11,93 @@ CheckOptions:
value: '10'
value: '10'
- key: readability-function-size.StatementThreshold
- key: readability-function-size.StatementThreshold
value: '150'
value: '150'
- key: readability-identifier-naming.Namespace
- key: readability-identifier-naming.Namespace
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.InlineNamespace
- key: readability-identifier-naming.InlineNamespace
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.EnumConstant
- key: readability-identifier-naming.EnumConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstexprVariable
- key: readability-identifier-naming.ConstexprVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstantMember
- key: readability-identifier-naming.ConstantMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PrivateMember
- key: readability-identifier-naming.PrivateMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ProtectedMember
- key: readability-identifier-naming.ProtectedMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PublicMember
- key: readability-identifier-naming.PublicMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Member
- key: readability-identifier-naming.Member
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ClassConstant
- key: readability-identifier-naming.ClassConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ClassMember
- key: readability-identifier-naming.ClassMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.GlobalConstant
- key: readability-identifier-naming.GlobalConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.GlobalVariable
- key: readability-identifier-naming.GlobalVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.LocalConstant
- key: readability-identifier-naming.LocalConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.LocalVariable
- key: readability-identifier-naming.LocalVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.StaticConstant
- key: readability-identifier-naming.StaticConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.StaticVariable
- key: readability-identifier-naming.StaticVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Constant
- key: readability-identifier-naming.Constant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Variable
- key: readability-identifier-naming.Variable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstantParameter
- key: readability-identifier-naming.ConstantParameter
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ParameterPack
- key: readability-identifier-naming.ParameterPack
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Parameter
- key: readability-identifier-naming.Parameter
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.AbstractClass
- key: readability-identifier-naming.AbstractClass
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Struct
- key: readability-identifier-naming.Struct
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Class
- key: readability-identifier-naming.Class
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Union
- key: readability-identifier-naming.Union
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Enum
- key: readability-identifier-naming.Enum
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.GlobalFunction
- key: readability-identifier-naming.GlobalFunction
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstexprFunction
- key: readability-identifier-naming.ConstexprFunction
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Function
- key: readability-identifier-naming.Function
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstexprMethod
- key: readability-identifier-naming.ConstexprMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.VirtualMethod
- key: readability-identifier-naming.VirtualMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ClassMethod
- key: readability-identifier-naming.ClassMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PrivateMethod
- key: readability-identifier-naming.PrivateMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ProtectedMethod
- key: readability-identifier-naming.ProtectedMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PublicMethod
- key: readability-identifier-naming.PublicMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Method
- key: readability-identifier-naming.Method
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Typedef
- key: readability-identifier-naming.Typedef
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.TypeTemplateParameter
- key: readability-identifier-naming.TypeTemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.ValueTemplateParameter
- key: readability-identifier-naming.ValueTemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.TemplateTemplateParameter
- key: readability-identifier-naming.TemplateTemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.TemplateParameter
- key: readability-identifier-naming.TemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.TypeAlias
- key: readability-identifier-naming.TypeAliasCase
value: lower_case
- key: readability-identifier-naming.MacroDefinition
value: lower_case
value: lower_case
# - key: readability-identifier-naming.MacroDefinitionCase
# value: UPPER_CASE
# - key: readability-identifier-naming.MacroDefinitionPrefix
# value: RTG_
include/rtg/argument.hpp
View file @
83aa9ac3
...
@@ -11,16 +11,16 @@ struct argument : raw_data<argument>
...
@@ -11,16 +11,16 @@ struct argument : raw_data<argument>
{
{
argument
()
{}
argument
()
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
shape
_
(
s
)
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
m_
shape
(
s
)
{}
std
::
function
<
char
*
()
>
data
;
std
::
function
<
char
*
()
>
data
;
bool
empty
()
const
{
return
not
data
;
}
bool
empty
()
const
{
return
not
data
;
}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
private:
private:
shape
shape
_
;
shape
m_
shape
;
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/literal.hpp
View file @
83aa9ac3
...
@@ -13,14 +13,14 @@ struct literal : raw_data<literal>
...
@@ -13,14 +13,14 @@ struct literal : raw_data<literal>
literal
()
{}
literal
()
{}
template
<
class
T
>
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
shape
_
(
shape
::
get_type
<
T
>
{})
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
m_
shape
(
shape
::
get_type
<
T
>
{})
{
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
...
@@ -28,7 +28,7 @@ struct literal : raw_data<literal>
...
@@ -28,7 +28,7 @@ struct literal : raw_data<literal>
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
...
@@ -36,29 +36,29 @@ struct literal : raw_data<literal>
...
@@ -36,29 +36,29 @@ struct literal : raw_data<literal>
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
shape
_
(
s
)
{}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
m_
shape
(
s
)
{}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
argument
get_argument
()
const
argument
get_argument
()
const
{
{
auto
b
=
buffer
;
auto
b
=
buffer
;
return
{
shape
_
,
[
b
]()
mutable
{
return
b
.
data
();
}};
return
{
m_
shape
,
[
b
]()
mutable
{
return
b
.
data
();
}};
}
}
private:
private:
std
::
vector
<
char
>
buffer
;
std
::
vector
<
char
>
buffer
;
shape
shape
_
;
shape
m_
shape
;
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/operand.hpp
View file @
83aa9ac3
...
@@ -28,112 +28,120 @@ struct operand
...
@@ -28,112 +28,120 @@ struct operand
// Constructors
// Constructors
operand
()
=
default
;
operand
()
=
default
;
template
<
typename
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedT
>
operand
(
TypeErased_T_
value
)
operand
(
PrivateDetailTypeErasedT
value
)
:
handle_mem_var_
(
:
private_detail_te_handle_mem_var
(
std
::
make_shared
<
handle_type_
<
typename
std
::
remove_reference
<
TypeErased_T_
>::
type
>>
(
std
::
make_shared
<
private_detail_te_handle_type
<
std
::
forward
<
TypeErased_T_
>
(
value
)))
typename
std
::
remove_reference
<
PrivateDetailTypeErasedT
>::
type
>>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
)))
{
{
}
}
// Assignment
// Assignment
template
<
typename
TypeErased
_T_
>
template
<
typename
PrivateDetail
TypeErased
T
>
operand
&
operator
=
(
TypeErased
_T_
value
)
operand
&
operator
=
(
PrivateDetail
TypeErased
T
value
)
{
{
if
(
handle_mem_var_
.
unique
())
if
(
private_detail_te_handle_mem_var
.
unique
())
*
handle_mem_var_
=
std
::
forward
<
TypeErased_T_
>
(
value
);
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
handle_mem_var_
)
else
if
(
!
private_detail_te_handle_mem_var
)
handle_mem_var_
=
std
::
make_shared
<
TypeErased_T_
>
(
std
::
forward
<
TypeErased_T_
>
(
value
));
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
return
*
this
;
return
*
this
;
}
}
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
handle_mem_var
_
);
assert
(
private_detail_te_
handle_mem_var
);
return
get_handle
_
().
name
();
return
private_detail_te_
get_handle
().
name
();
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
{
assert
(
handle_mem_var
_
);
assert
(
private_detail_te_
handle_mem_var
);
return
get_handle
_
().
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_
get_handle
().
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
{
assert
(
handle_mem_var
_
);
assert
(
private_detail_te_
handle_mem_var
);
return
get_handle
_
().
compute
(
std
::
move
(
input
));
return
private_detail_te_
get_handle
().
compute
(
std
::
move
(
input
));
}
}
private:
private:
struct
handle_base_type
_
struct
private_detail_te_
handle_base_type
{
{
virtual
~
handle_base_type
_
()
{}
virtual
~
private_detail_te_
handle_base_type
()
{}
virtual
std
::
shared_ptr
<
handle_base_type
_
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
};
};
template
<
typename
TypeErased
_T_
>
template
<
typename
PrivateDetail
TypeErased
T
>
struct
handle_type
_
:
handle_base_type
_
struct
private_detail_te_
handle_type
:
private_detail_te_
handle_base_type
{
{
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
handle_type_
(
private_detail_te_handle_type
(
TypeErased_T_
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
TypeErased_U_
>::
value
>::
type
*
=
nullptr
)
typename
std
::
enable_if
<
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
>::
type
*
=
:
value_
(
value
)
nullptr
)
:
private_detail_te_value
(
value
)
{
{
}
}
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
handle_type_
(
TypeErased_T_
value
,
private_detail_te_handle_type
(
typename
std
::
enable_if
<!
std
::
is_reference
<
TypeErased_U_
>::
value
,
int
>::
type
*
=
PrivateDetailTypeErasedT
value
,
nullptr
)
noexcept
typename
std
::
enable_if
<!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
:
value_
(
std
::
move
(
value
))
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
}
}
std
::
shared_ptr
<
handle_base_type
_
>
clone
()
const
override
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
clone
()
const
override
{
{
return
std
::
make_shared
<
handle_type
_
>
(
value
_
);
return
std
::
make_shared
<
private_detail_te_
handle_type
>
(
private_detail_te_
value
);
}
}
std
::
string
name
()
const
override
{
return
value
_
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_
value
.
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
{
{
return
value
_
.
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_
value
.
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
{
{
return
value
_
.
compute
(
std
::
move
(
input
));
return
private_detail_te_
value
.
compute
(
std
::
move
(
input
));
}
}
TypeErased_T_
value_
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
template
<
typename
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedT
>
struct
handle_type_
<
std
::
reference_wrapper
<
TypeErased_T_
>>
:
handle_type_
<
TypeErased_T_
&>
struct
private_detail_te_handle_type
<
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>>
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
{
{
handle_type
_
(
std
::
reference_wrapper
<
TypeErased
_T_
>
ref
)
private_detail_te_
handle_type
(
std
::
reference_wrapper
<
PrivateDetail
TypeErased
T
>
ref
)
:
handle_type_
<
TypeErased
_T_
&>
(
ref
.
get
())
:
private_detail_te_handle_type
<
PrivateDetail
TypeErased
T
&>
(
ref
.
get
())
{
{
}
}
};
};
const
handle_base_type_
&
get_handle_
()
const
{
return
*
handle_mem_var_
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
return
*
private_detail_te_handle_mem_var
;
}
handle_base_type
_
&
get_handle
_
()
private_detail_te_
handle_base_type
&
private_detail_te_
get_handle
()
{
{
if
(
!
handle_mem_var
_
.
unique
())
if
(
!
private_detail_te_
handle_mem_var
.
unique
())
handle_mem_var
_
=
handle_mem_var
_
->
clone
();
private_detail_te_
handle_mem_var
=
private_detail_te_
handle_mem_var
->
clone
();
return
*
handle_mem_var
_
;
return
*
private_detail_te_
handle_mem_var
;
}
}
std
::
shared_ptr
<
handle_base_type
_
>
handle_mem_var
_
;
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
private_detail_te_
handle_mem_var
;
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/shape.hpp
View file @
83aa9ac3
...
@@ -110,7 +110,7 @@ struct shape
...
@@ -110,7 +110,7 @@ struct shape
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
void
visit_type
(
Visitor
v
)
const
{
{
switch
(
this
->
type
_
)
switch
(
this
->
m_
type
)
{
{
#define RTG_SHAPE_VISITOR_CASE(x, t) \
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
case x: v(as<t>()); return;
...
@@ -121,10 +121,10 @@ struct shape
...
@@ -121,10 +121,10 @@ struct shape
}
}
private:
private:
type_t
type
_
;
type_t
m_
type
;
std
::
vector
<
std
::
size_t
>
lens
_
;
std
::
vector
<
std
::
size_t
>
m_
lens
;
std
::
vector
<
std
::
size_t
>
strides
_
;
std
::
vector
<
std
::
size_t
>
m_
strides
;
bool
packed
_
;
bool
m_
packed
;
void
calculate_strides
();
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
...
...
include/rtg/tensor_view.hpp
View file @
83aa9ac3
...
@@ -11,103 +11,103 @@ namespace rtg {
...
@@ -11,103 +11,103 @@ namespace rtg {
template
<
class
T
>
template
<
class
T
>
struct
tensor_view
struct
tensor_view
{
{
tensor_view
()
:
data
_
(
nullptr
)
{}
tensor_view
()
:
m_
data
(
nullptr
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
data
_
(
d
),
shape
_
(
s
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
m_
data
(
d
),
m_
shape
(
s
)
{}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
bool
empty
()
const
{
return
data
_
==
nullptr
||
shape
_
.
lens
().
empty
();
}
bool
empty
()
const
{
return
m_
data
==
nullptr
||
m_
shape
.
lens
().
empty
();
}
std
::
size_t
size
()
const
{
return
shape
_
.
elements
();
}
std
::
size_t
size
()
const
{
return
m_
shape
.
elements
();
}
T
*
data
()
{
return
this
->
data
_
;
}
T
*
data
()
{
return
this
->
m_
data
;
}
const
T
*
data
()
const
{
return
this
->
data
_
;
}
const
T
*
data
()
const
{
return
this
->
m_
data
;
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
const
T
&
operator
()(
Ts
...
xs
)
const
const
T
&
operator
()(
Ts
...
xs
)
const
{
{
return
data
_
[
shape
_
.
index
({
xs
...})];
return
m_
data
[
m_
shape
.
index
({
xs
...})];
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
T
&
operator
()(
Ts
...
xs
)
T
&
operator
()(
Ts
...
xs
)
{
{
return
data
_
[
shape
_
.
index
({
xs
...})];
return
m_
data
[
m_
shape
.
index
({
xs
...})];
}
}
T
&
operator
[](
std
::
size_t
i
)
T
&
operator
[](
std
::
size_t
i
)
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data
_
[
shape
_
.
index
(
i
)];
return
m_
data
[
m_
shape
.
index
(
i
)];
}
}
const
T
&
operator
[](
std
::
size_t
i
)
const
const
T
&
operator
[](
std
::
size_t
i
)
const
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data
_
[
shape
_
.
index
(
i
)];
return
m_
data
[
m_
shape
.
index
(
i
)];
}
}
T
&
front
()
T
&
front
()
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
0
];
return
m_
data
[
0
];
}
}
const
T
&
front
()
const
const
T
&
front
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
0
];
return
m_
data
[
0
];
}
}
T
&
back
()
T
&
back
()
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
shape
_
.
index
(
this
->
size
()
-
1
)];
return
m_
data
[
m_
shape
.
index
(
this
->
size
()
-
1
)];
}
}
const
T
&
back
()
const
const
T
&
back
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
shape
_
.
index
(
this
->
size
()
-
1
)];
return
m_
data
[
m_
shape
.
index
(
this
->
size
()
-
1
)];
}
}
// TODO: Add iterators so it can handle nonpacked tensors
// TODO: Add iterators so it can handle nonpacked tensors
T
*
begin
()
T
*
begin
()
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
return
data
_
;
return
m_
data
;
}
}
T
*
end
()
T
*
end
()
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
data
_
;
return
m_
data
;
else
else
return
data
_
+
this
->
size
();
return
m_
data
+
this
->
size
();
}
}
const
T
*
begin
()
const
const
T
*
begin
()
const
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
return
data
_
;
return
m_
data
;
}
}
const
T
*
end
()
const
const
T
*
end
()
const
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
data
_
;
return
m_
data
;
else
else
return
data
_
+
this
->
size
();
return
m_
data
+
this
->
size
();
}
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
{
if
(
x
.
shape
_
==
y
.
shape
_
)
if
(
x
.
m_
shape
==
y
.
m_
shape
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape
_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
.
m_
shape
.
elements
();
i
++
)
{
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
return
false
;
...
@@ -124,17 +124,17 @@ struct tensor_view
...
@@ -124,17 +124,17 @@ struct tensor_view
if
(
!
x
.
empty
())
if
(
!
x
.
empty
())
{
{
os
<<
x
.
front
();
os
<<
x
.
front
();
for
(
std
::
size_t
i
=
1
;
i
<
x
.
shape
_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_
shape
.
elements
();
i
++
)
{
{
os
<<
", "
<<
x
.
data
_
[
x
.
shape
_
.
index
(
i
)];
os
<<
", "
<<
x
.
m_
data
[
x
.
m_
shape
.
index
(
i
)];
}
}
}
}
return
os
;
return
os
;
}
}
private:
private:
T
*
data
_
;
T
*
m_
data
;
shape
shape
_
;
shape
m_
shape
;
};
};
template
<
class
T
>
template
<
class
T
>
...
...
src/shape.cpp
View file @
83aa9ac3
...
@@ -7,35 +7,35 @@
...
@@ -7,35 +7,35 @@
namespace
rtg
{
namespace
rtg
{
shape
::
shape
()
:
type
_
(
float_type
),
packed
_
(
false
)
{}
shape
::
shape
()
:
m_
type
(
float_type
),
m_
packed
(
false
)
{}
shape
::
shape
(
type_t
t
)
:
type
_
(
t
),
lens
_
({
1
}),
strides
_
({
1
}),
packed
_
(
true
)
{}
shape
::
shape
(
type_t
t
)
:
m_
type
(
t
),
m_
lens
({
1
}),
m_
strides
({
1
}),
m_
packed
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type
_
(
t
),
lens
_
(
std
::
move
(
l
)),
packed
_
(
true
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_
type
(
t
),
m_
lens
(
std
::
move
(
l
)),
m_
packed
(
true
)
{
{
this
->
calculate_strides
();
this
->
calculate_strides
();
assert
(
lens
_
.
size
()
==
strides
_
.
size
());
assert
(
m_
lens
.
size
()
==
m_
strides
.
size
());
}
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
type
_
(
t
),
lens
_
(
std
::
move
(
l
)),
strides
_
(
std
::
move
(
s
))
:
m_
type
(
t
),
m_
lens
(
std
::
move
(
l
)),
m_
strides
(
std
::
move
(
s
))
{
{
assert
(
lens
_
.
size
()
==
strides
_
.
size
());
assert
(
m_
lens
.
size
()
==
m_
strides
.
size
());
packed
_
=
this
->
elements
()
==
this
->
element_space
();
m_
packed
=
this
->
elements
()
==
this
->
element_space
();
}
}
void
shape
::
calculate_strides
()
void
shape
::
calculate_strides
()
{
{
strides
_
.
clear
();
m_
strides
.
clear
();
strides
_
.
resize
(
lens
_
.
size
(),
0
);
m_
strides
.
resize
(
m_
lens
.
size
(),
0
);
if
(
strides
_
.
empty
())
if
(
m_
strides
.
empty
())
return
;
return
;
strides
_
.
back
()
=
1
;
m_
strides
.
back
()
=
1
;
std
::
partial_sum
(
std
::
partial_sum
(
lens
_
.
rbegin
(),
lens
_
.
rend
()
-
1
,
strides
_
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
m_
lens
.
rbegin
(),
m_
lens
.
rend
()
-
1
,
m_
strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
type
_
;
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
m_
type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
lens
_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
m_
lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
strides
_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
m_
strides
;
}
std
::
size_t
shape
::
elements
()
const
std
::
size_t
shape
::
elements
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
...
@@ -71,7 +71,7 @@ std::size_t shape::index(std::size_t i) const
...
@@ -71,7 +71,7 @@ std::size_t shape::index(std::size_t i) const
std
::
plus
<
std
::
size_t
>
{},
std
::
plus
<
std
::
size_t
>
{},
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
}
}
bool
shape
::
packed
()
const
{
return
this
->
packed
_
;
}
bool
shape
::
packed
()
const
{
return
this
->
m_
packed
;
}
std
::
size_t
shape
::
element_space
()
const
std
::
size_t
shape
::
element_space
()
const
{
{
// TODO: Get rid of intermediate vector
// TODO: Get rid of intermediate vector
...
@@ -89,7 +89,7 @@ std::size_t shape::element_space() const
...
@@ -89,7 +89,7 @@ std::size_t shape::element_space() const
std
::
string
shape
::
type_string
()
const
std
::
string
shape
::
type_string
()
const
{
{
switch
(
this
->
type
_
)
switch
(
this
->
m_
type
)
{
{
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
case x: return #x;
...
...
test/test.hpp
View file @
83aa9ac3
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
#ifndef GUARD_TEST_TEST_HPP
_
#ifndef GUARD_TEST_TEST_HPP
#define GUARD_TEST_TEST_HPP
_
#define GUARD_TEST_TEST_HPP
inline
void
failed
(
const
char
*
msg
,
const
char
*
file
,
int
line
)
inline
void
failed
(
const
char
*
msg
,
const
char
*
file
,
int
line
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment