Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
83aa9ac3
Commit
83aa9ac3
authored
Apr 23, 2018
by
Paul
Browse files
Enforce correct naming
parent
81778d3d
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
173 additions
and
163 deletions
+173
-163
.clang-tidy
.clang-tidy
+51
-49
include/rtg/argument.hpp
include/rtg/argument.hpp
+3
-3
include/rtg/literal.hpp
include/rtg/literal.hpp
+8
-8
include/rtg/operand.hpp
include/rtg/operand.hpp
+56
-48
include/rtg/shape.hpp
include/rtg/shape.hpp
+5
-5
include/rtg/tensor_view.hpp
include/rtg/tensor_view.hpp
+31
-31
src/shape.cpp
src/shape.cpp
+17
-17
test/test.hpp
test/test.hpp
+2
-2
No files found.
.clang-tidy
View file @
83aa9ac3
...
@@ -11,91 +11,93 @@ CheckOptions:
...
@@ -11,91 +11,93 @@ CheckOptions:
value: '10'
value: '10'
- key: readability-function-size.StatementThreshold
- key: readability-function-size.StatementThreshold
value: '150'
value: '150'
- key: readability-identifier-naming.Namespace
- key: readability-identifier-naming.Namespace
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.InlineNamespace
- key: readability-identifier-naming.InlineNamespace
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.EnumConstant
- key: readability-identifier-naming.EnumConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstexprVariable
- key: readability-identifier-naming.ConstexprVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstantMember
- key: readability-identifier-naming.ConstantMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PrivateMember
- key: readability-identifier-naming.PrivateMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ProtectedMember
- key: readability-identifier-naming.ProtectedMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PublicMember
- key: readability-identifier-naming.PublicMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Member
- key: readability-identifier-naming.Member
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ClassConstant
- key: readability-identifier-naming.ClassConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ClassMember
- key: readability-identifier-naming.ClassMember
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.GlobalConstant
- key: readability-identifier-naming.GlobalConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.GlobalVariable
- key: readability-identifier-naming.GlobalVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.LocalConstant
- key: readability-identifier-naming.LocalConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.LocalVariable
- key: readability-identifier-naming.LocalVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.StaticConstant
- key: readability-identifier-naming.StaticConstant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.StaticVariable
- key: readability-identifier-naming.StaticVariable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Constant
- key: readability-identifier-naming.Constant
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Variable
- key: readability-identifier-naming.Variable
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstantParameter
- key: readability-identifier-naming.ConstantParameter
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ParameterPack
- key: readability-identifier-naming.ParameterPack
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Parameter
- key: readability-identifier-naming.Parameter
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.AbstractClass
- key: readability-identifier-naming.AbstractClass
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Struct
- key: readability-identifier-naming.Struct
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Class
- key: readability-identifier-naming.Class
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Union
- key: readability-identifier-naming.Union
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Enum
- key: readability-identifier-naming.Enum
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.GlobalFunction
- key: readability-identifier-naming.GlobalFunction
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstexprFunction
- key: readability-identifier-naming.ConstexprFunction
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Function
- key: readability-identifier-naming.Function
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ConstexprMethod
- key: readability-identifier-naming.ConstexprMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.VirtualMethod
- key: readability-identifier-naming.VirtualMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ClassMethod
- key: readability-identifier-naming.ClassMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PrivateMethod
- key: readability-identifier-naming.PrivateMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.ProtectedMethod
- key: readability-identifier-naming.ProtectedMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.PublicMethod
- key: readability-identifier-naming.PublicMethod
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Method
- key: readability-identifier-naming.Method
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.Typedef
- key: readability-identifier-naming.Typedef
Case
value: lower_case
value: lower_case
- key: readability-identifier-naming.TypeTemplateParameter
- key: readability-identifier-naming.TypeTemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.ValueTemplateParameter
- key: readability-identifier-naming.ValueTemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.TemplateTemplateParameter
- key: readability-identifier-naming.TemplateTemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.TemplateParameter
- key: readability-identifier-naming.TemplateParameterCase
value: lower_case
value: CamelCase
- key: readability-identifier-naming.TypeAlias
- key: readability-identifier-naming.TypeAliasCase
value: lower_case
- key: readability-identifier-naming.MacroDefinition
value: lower_case
value: lower_case
# - key: readability-identifier-naming.MacroDefinitionCase
# value: UPPER_CASE
# - key: readability-identifier-naming.MacroDefinitionPrefix
# value: RTG_
include/rtg/argument.hpp
View file @
83aa9ac3
...
@@ -11,16 +11,16 @@ struct argument : raw_data<argument>
...
@@ -11,16 +11,16 @@ struct argument : raw_data<argument>
{
{
argument
()
{}
argument
()
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
shape
_
(
s
)
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
m_
shape
(
s
)
{}
std
::
function
<
char
*
()
>
data
;
std
::
function
<
char
*
()
>
data
;
bool
empty
()
const
{
return
not
data
;
}
bool
empty
()
const
{
return
not
data
;
}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
private:
private:
shape
shape
_
;
shape
m_
shape
;
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/literal.hpp
View file @
83aa9ac3
...
@@ -13,14 +13,14 @@ struct literal : raw_data<literal>
...
@@ -13,14 +13,14 @@ struct literal : raw_data<literal>
literal
()
{}
literal
()
{}
template
<
class
T
>
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
shape
_
(
shape
::
get_type
<
T
>
{})
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
m_
shape
(
shape
::
get_type
<
T
>
{})
{
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
...
@@ -28,7 +28,7 @@ struct literal : raw_data<literal>
...
@@ -28,7 +28,7 @@ struct literal : raw_data<literal>
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
...
@@ -36,29 +36,29 @@ struct literal : raw_data<literal>
...
@@ -36,29 +36,29 @@ struct literal : raw_data<literal>
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape
_
(
s
)
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
m_
shape
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
shape
_
(
s
)
{}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
m_
shape
(
s
)
{}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
argument
get_argument
()
const
argument
get_argument
()
const
{
{
auto
b
=
buffer
;
auto
b
=
buffer
;
return
{
shape
_
,
[
b
]()
mutable
{
return
b
.
data
();
}};
return
{
m_
shape
,
[
b
]()
mutable
{
return
b
.
data
();
}};
}
}
private:
private:
std
::
vector
<
char
>
buffer
;
std
::
vector
<
char
>
buffer
;
shape
shape
_
;
shape
m_
shape
;
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/operand.hpp
View file @
83aa9ac3
...
@@ -28,112 +28,120 @@ struct operand
...
@@ -28,112 +28,120 @@ struct operand
// Constructors
// Constructors
operand
()
=
default
;
operand
()
=
default
;
template
<
typename
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedT
>
operand
(
TypeErased_T_
value
)
operand
(
PrivateDetailTypeErasedT
value
)
:
handle_mem_var_
(
:
private_detail_te_handle_mem_var
(
std
::
make_shared
<
handle_type_
<
typename
std
::
remove_reference
<
TypeErased_T_
>::
type
>>
(
std
::
make_shared
<
private_detail_te_handle_type
<
std
::
forward
<
TypeErased_T_
>
(
value
)))
typename
std
::
remove_reference
<
PrivateDetailTypeErasedT
>::
type
>>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
)))
{
{
}
}
// Assignment
// Assignment
template
<
typename
TypeErased
_T_
>
template
<
typename
PrivateDetail
TypeErased
T
>
operand
&
operator
=
(
TypeErased
_T_
value
)
operand
&
operator
=
(
PrivateDetail
TypeErased
T
value
)
{
{
if
(
handle_mem_var_
.
unique
())
if
(
private_detail_te_handle_mem_var
.
unique
())
*
handle_mem_var_
=
std
::
forward
<
TypeErased_T_
>
(
value
);
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
handle_mem_var_
)
else
if
(
!
private_detail_te_handle_mem_var
)
handle_mem_var_
=
std
::
make_shared
<
TypeErased_T_
>
(
std
::
forward
<
TypeErased_T_
>
(
value
));
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
return
*
this
;
return
*
this
;
}
}
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
handle_mem_var
_
);
assert
(
private_detail_te_
handle_mem_var
);
return
get_handle
_
().
name
();
return
private_detail_te_
get_handle
().
name
();
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
{
assert
(
handle_mem_var
_
);
assert
(
private_detail_te_
handle_mem_var
);
return
get_handle
_
().
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_
get_handle
().
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
{
assert
(
handle_mem_var
_
);
assert
(
private_detail_te_
handle_mem_var
);
return
get_handle
_
().
compute
(
std
::
move
(
input
));
return
private_detail_te_
get_handle
().
compute
(
std
::
move
(
input
));
}
}
private:
private:
struct
handle_base_type
_
struct
private_detail_te_
handle_base_type
{
{
virtual
~
handle_base_type
_
()
{}
virtual
~
private_detail_te_
handle_base_type
()
{}
virtual
std
::
shared_ptr
<
handle_base_type
_
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
};
};
template
<
typename
TypeErased
_T_
>
template
<
typename
PrivateDetail
TypeErased
T
>
struct
handle_type
_
:
handle_base_type
_
struct
private_detail_te_
handle_type
:
private_detail_te_
handle_base_type
{
{
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
handle_type_
(
private_detail_te_handle_type
(
TypeErased_T_
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
TypeErased_U_
>::
value
>::
type
*
=
nullptr
)
typename
std
::
enable_if
<
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
>::
type
*
=
:
value_
(
value
)
nullptr
)
:
private_detail_te_value
(
value
)
{
{
}
}
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
handle_type_
(
TypeErased_T_
value
,
private_detail_te_handle_type
(
typename
std
::
enable_if
<!
std
::
is_reference
<
TypeErased_U_
>::
value
,
int
>::
type
*
=
PrivateDetailTypeErasedT
value
,
nullptr
)
noexcept
typename
std
::
enable_if
<!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
:
value_
(
std
::
move
(
value
))
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
}
}
std
::
shared_ptr
<
handle_base_type
_
>
clone
()
const
override
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
clone
()
const
override
{
{
return
std
::
make_shared
<
handle_type
_
>
(
value
_
);
return
std
::
make_shared
<
private_detail_te_
handle_type
>
(
private_detail_te_
value
);
}
}
std
::
string
name
()
const
override
{
return
value
_
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_
value
.
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
{
{
return
value
_
.
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_
value
.
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
{
{
return
value
_
.
compute
(
std
::
move
(
input
));
return
private_detail_te_
value
.
compute
(
std
::
move
(
input
));
}
}
TypeErased_T_
value_
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
template
<
typename
TypeErased_T_
>
template
<
typename
PrivateDetailTypeErasedT
>
struct
handle_type_
<
std
::
reference_wrapper
<
TypeErased_T_
>>
:
handle_type_
<
TypeErased_T_
&>
struct
private_detail_te_handle_type
<
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>>
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
{
{
handle_type
_
(
std
::
reference_wrapper
<
TypeErased
_T_
>
ref
)
private_detail_te_
handle_type
(
std
::
reference_wrapper
<
PrivateDetail
TypeErased
T
>
ref
)
:
handle_type_
<
TypeErased
_T_
&>
(
ref
.
get
())
:
private_detail_te_handle_type
<
PrivateDetail
TypeErased
T
&>
(
ref
.
get
())
{
{
}
}
};
};
const
handle_base_type_
&
get_handle_
()
const
{
return
*
handle_mem_var_
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
return
*
private_detail_te_handle_mem_var
;
}
handle_base_type
_
&
get_handle
_
()
private_detail_te_
handle_base_type
&
private_detail_te_
get_handle
()
{
{
if
(
!
handle_mem_var
_
.
unique
())
if
(
!
private_detail_te_
handle_mem_var
.
unique
())
handle_mem_var
_
=
handle_mem_var
_
->
clone
();
private_detail_te_
handle_mem_var
=
private_detail_te_
handle_mem_var
->
clone
();
return
*
handle_mem_var
_
;
return
*
private_detail_te_
handle_mem_var
;
}
}
std
::
shared_ptr
<
handle_base_type
_
>
handle_mem_var
_
;
std
::
shared_ptr
<
private_detail_te_
handle_base_type
>
private_detail_te_
handle_mem_var
;
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/shape.hpp
View file @
83aa9ac3
...
@@ -110,7 +110,7 @@ struct shape
...
@@ -110,7 +110,7 @@ struct shape
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
void
visit_type
(
Visitor
v
)
const
{
{
switch
(
this
->
type
_
)
switch
(
this
->
m_
type
)
{
{
#define RTG_SHAPE_VISITOR_CASE(x, t) \
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
case x: v(as<t>()); return;
...
@@ -121,10 +121,10 @@ struct shape
...
@@ -121,10 +121,10 @@ struct shape
}
}
private:
private:
type_t
type
_
;
type_t
m_
type
;
std
::
vector
<
std
::
size_t
>
lens
_
;
std
::
vector
<
std
::
size_t
>
m_
lens
;
std
::
vector
<
std
::
size_t
>
strides
_
;
std
::
vector
<
std
::
size_t
>
m_
strides
;
bool
packed
_
;
bool
m_
packed
;
void
calculate_strides
();
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
...
...
include/rtg/tensor_view.hpp
View file @
83aa9ac3
...
@@ -11,103 +11,103 @@ namespace rtg {
...
@@ -11,103 +11,103 @@ namespace rtg {
template
<
class
T
>
template
<
class
T
>
struct
tensor_view
struct
tensor_view
{
{
tensor_view
()
:
data
_
(
nullptr
)
{}
tensor_view
()
:
m_
data
(
nullptr
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
data
_
(
d
),
shape
_
(
s
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
m_
data
(
d
),
m_
shape
(
s
)
{}
const
shape
&
get_shape
()
const
{
return
this
->
shape
_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_
shape
;
}
bool
empty
()
const
{
return
data
_
==
nullptr
||
shape
_
.
lens
().
empty
();
}
bool
empty
()
const
{
return
m_
data
==
nullptr
||
m_
shape
.
lens
().
empty
();
}
std
::
size_t
size
()
const
{
return
shape
_
.
elements
();
}
std
::
size_t
size
()
const
{
return
m_
shape
.
elements
();
}
T
*
data
()
{
return
this
->
data
_
;
}
T
*
data
()
{
return
this
->
m_
data
;
}
const
T
*
data
()
const
{
return
this
->
data
_
;
}
const
T
*
data
()
const
{
return
this
->
m_
data
;
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
const
T
&
operator
()(
Ts
...
xs
)
const
const
T
&
operator
()(
Ts
...
xs
)
const
{
{
return
data
_
[
shape
_
.
index
({
xs
...})];
return
m_
data
[
m_
shape
.
index
({
xs
...})];
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
T
&
operator
()(
Ts
...
xs
)
T
&
operator
()(
Ts
...
xs
)
{
{
return
data
_
[
shape
_
.
index
({
xs
...})];
return
m_
data
[
m_
shape
.
index
({
xs
...})];
}
}
T
&
operator
[](
std
::
size_t
i
)
T
&
operator
[](
std
::
size_t
i
)
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data
_
[
shape
_
.
index
(
i
)];
return
m_
data
[
m_
shape
.
index
(
i
)];
}
}
const
T
&
operator
[](
std
::
size_t
i
)
const
const
T
&
operator
[](
std
::
size_t
i
)
const
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
return
data
_
[
shape
_
.
index
(
i
)];
return
m_
data
[
m_
shape
.
index
(
i
)];
}
}
T
&
front
()
T
&
front
()
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
0
];
return
m_
data
[
0
];
}
}
const
T
&
front
()
const
const
T
&
front
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
0
];
return
m_
data
[
0
];
}
}
T
&
back
()
T
&
back
()
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
shape
_
.
index
(
this
->
size
()
-
1
)];
return
m_
data
[
m_
shape
.
index
(
this
->
size
()
-
1
)];
}
}
const
T
&
back
()
const
const
T
&
back
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data
_
[
shape
_
.
index
(
this
->
size
()
-
1
)];
return
m_
data
[
m_
shape
.
index
(
this
->
size
()
-
1
)];
}
}
// TODO: Add iterators so it can handle nonpacked tensors
// TODO: Add iterators so it can handle nonpacked tensors
T
*
begin
()
T
*
begin
()
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
return
data
_
;
return
m_
data
;
}
}
T
*
end
()
T
*
end
()
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
data
_
;
return
m_
data
;
else
else
return
data
_
+
this
->
size
();
return
m_
data
+
this
->
size
();
}
}
const
T
*
begin
()
const
const
T
*
begin
()
const
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
return
data
_
;
return
m_
data
;
}
}
const
T
*
end
()
const
const
T
*
end
()
const
{
{
assert
(
this
->
shape
_
.
packed
());
assert
(
this
->
m_
shape
.
packed
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
data
_
;
return
m_
data
;
else
else
return
data
_
+
this
->
size
();
return
m_
data
+
this
->
size
();
}
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
{
if
(
x
.
shape
_
==
y
.
shape
_
)
if
(
x
.
m_
shape
==
y
.
m_
shape
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape
_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
.
m_
shape
.
elements
();
i
++
)
{
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
return
false
;
...
@@ -124,17 +124,17 @@ struct tensor_view
...
@@ -124,17 +124,17 @@ struct tensor_view
if
(
!
x
.
empty
())
if
(
!
x
.
empty
())
{
{
os
<<
x
.
front
();
os
<<
x
.
front
();
for
(
std
::
size_t
i
=
1
;
i
<
x
.
shape
_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_
shape
.
elements
();
i
++
)
{
{
os
<<
", "
<<
x
.
data
_
[
x
.
shape
_
.
index
(
i
)];
os
<<
", "
<<
x
.
m_
data
[
x
.
m_
shape
.
index
(
i
)];
}
}
}
}
return
os
;
return
os
;
}
}
private:
private:
T
*
data
_
;
T
*
m_
data
;
shape
shape
_
;
shape
m_
shape
;
};
};
template
<
class
T
>
template
<
class
T
>
...
...
src/shape.cpp
View file @
83aa9ac3
...
@@ -7,35 +7,35 @@
...
@@ -7,35 +7,35 @@
namespace
rtg
{
namespace
rtg
{
shape
::
shape
()
:
type
_
(
float_type
),
packed
_
(
false
)
{}
shape
::
shape
()
:
m_
type
(
float_type
),
m_
packed
(
false
)
{}
shape
::
shape
(
type_t
t
)
:
type
_
(
t
),
lens
_
({
1
}),
strides
_
({
1
}),
packed
_
(
true
)
{}
shape
::
shape
(
type_t
t
)
:
m_
type
(
t
),
m_
lens
({
1
}),
m_
strides
({
1
}),
m_
packed
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type
_
(
t
),
lens
_
(
std
::
move
(
l
)),
packed
_
(
true
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_
type
(
t
),
m_
lens
(
std
::
move
(
l
)),
m_
packed
(
true
)
{
{
this
->
calculate_strides
();
this
->
calculate_strides
();
assert
(
lens
_
.
size
()
==
strides
_
.
size
());
assert
(
m_
lens
.
size
()
==
m_
strides
.
size
());
}
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
type
_
(
t
),
lens
_
(
std
::
move
(
l
)),
strides
_
(
std
::
move
(
s
))
:
m_
type
(
t
),
m_
lens
(
std
::
move
(
l
)),
m_
strides
(
std
::
move
(
s
))
{
{
assert
(
lens
_
.
size
()
==
strides
_
.
size
());
assert
(
m_
lens
.
size
()
==
m_
strides
.
size
());
packed
_
=
this
->
elements
()
==
this
->
element_space
();
m_
packed
=
this
->
elements
()
==
this
->
element_space
();
}
}
void
shape
::
calculate_strides
()
void
shape
::
calculate_strides
()
{
{
strides
_
.
clear
();
m_
strides
.
clear
();
strides
_
.
resize
(
lens
_
.
size
(),
0
);
m_
strides
.
resize
(
m_
lens
.
size
(),
0
);
if
(
strides
_
.
empty
())
if
(
m_
strides
.
empty
())
return
;
return
;
strides
_
.
back
()
=
1
;
m_
strides
.
back
()
=
1
;
std
::
partial_sum
(
std
::
partial_sum
(
lens
_
.
rbegin
(),
lens
_
.
rend
()
-
1
,
strides
_
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
m_
lens
.
rbegin
(),
m_
lens
.
rend
()
-
1
,
m_
strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
type
_
;
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
m_
type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
lens
_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
m_
lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
strides
_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
m_
strides
;
}
std
::
size_t
shape
::
elements
()
const
std
::
size_t
shape
::
elements
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
...
@@ -71,7 +71,7 @@ std::size_t shape::index(std::size_t i) const
...
@@ -71,7 +71,7 @@ std::size_t shape::index(std::size_t i) const
std
::
plus
<
std
::
size_t
>
{},
std
::
plus
<
std
::
size_t
>
{},
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
}
}
bool
shape
::
packed
()
const
{
return
this
->
packed
_
;
}
bool
shape
::
packed
()
const
{
return
this
->
m_
packed
;
}
std
::
size_t
shape
::
element_space
()
const
std
::
size_t
shape
::
element_space
()
const
{
{
// TODO: Get rid of intermediate vector
// TODO: Get rid of intermediate vector
...
@@ -89,7 +89,7 @@ std::size_t shape::element_space() const
...
@@ -89,7 +89,7 @@ std::size_t shape::element_space() const
std
::
string
shape
::
type_string
()
const
std
::
string
shape
::
type_string
()
const
{
{
switch
(
this
->
type
_
)
switch
(
this
->
m_
type
)
{
{
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
case x: return #x;
...
...
test/test.hpp
View file @
83aa9ac3
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
#ifndef GUARD_TEST_TEST_HPP
_
#ifndef GUARD_TEST_TEST_HPP
#define GUARD_TEST_TEST_HPP
_
#define GUARD_TEST_TEST_HPP
inline
void
failed
(
const
char
*
msg
,
const
char
*
file
,
int
line
)
inline
void
failed
(
const
char
*
msg
,
const
char
*
file
,
int
line
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment