Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8724a471
Commit
8724a471
authored
Aug 30, 2018
by
Paul
Browse files
Fix indexing in the shape class
parent
17c6d683
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
86 additions
and
14 deletions
+86
-14
src/shape.cpp
src/shape.cpp
+14
-9
test/shape_test.cpp
test/shape_test.cpp
+72
-5
No files found.
src/shape.cpp
View file @
8724a471
...
...
@@ -116,15 +116,20 @@ std::size_t shape::index(std::size_t i) const
if
(
this
->
standard
())
return
i
;
else
return
std
::
inner_product
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
assert
(
stride
>
0
and
len
>
0
);
return
((
i
/
stride
)
%
len
)
*
stride
;
});
{
std
::
size_t
s
=
1
;
std
::
size_t
result
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
this
->
lens
().
size
();
j
++
)
{
const
std
::
size_t
k
=
this
->
lens
().
size
()
-
j
-
1
;
const
std
::
size_t
stride
=
this
->
strides
()[
k
];
const
std
::
size_t
len
=
this
->
lens
()[
k
];
const
std
::
size_t
idx
=
(
i
%
(
s
*
len
))
/
s
;
result
+=
stride
*
idx
;
s
*=
len
;
}
return
result
;
}
}
bool
shape
::
packed
()
const
{
return
this
->
elements
()
==
this
->
element_space
();
}
...
...
test/shape_test.cpp
View file @
8724a471
...
...
@@ -97,6 +97,72 @@ void test_shape4()
EXPECT
(
s
.
index
(
s
.
elements
()
-
1
)
==
s
.
elements
()
-
1
);
}
void
test_shape42
()
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
},
{
2048
,
64
,
8
,
1
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
EXPECT
(
s
.
lens
()[
2
]
==
8
);
EXPECT
(
s
.
lens
()[
3
]
==
8
);
EXPECT
(
s
.
strides
()[
0
]
==
s
.
lens
()[
1
]
*
s
.
strides
()[
1
]);
EXPECT
(
s
.
strides
()[
1
]
==
s
.
lens
()[
2
]
*
s
.
strides
()[
2
]);
EXPECT
(
s
.
strides
()[
2
]
==
s
.
lens
()[
3
]
*
s
.
strides
()[
3
]);
EXPECT
(
s
.
strides
()[
3
]
==
1
);
EXPECT
(
s
.
elements
()
==
100
*
32
*
8
*
8
);
EXPECT
(
s
.
bytes
()
==
100
*
32
*
8
*
8
*
sizeof
(
float
));
EXPECT
(
s
.
index
({
0
,
0
,
0
,
0
})
==
0
);
EXPECT
(
s
.
index
({
0
,
0
,
0
,
1
})
==
1
);
EXPECT
(
s
.
index
({
0
,
0
,
0
,
0
})
==
s
.
index
(
0
));
EXPECT
(
s
.
index
({
0
,
0
,
0
,
1
})
==
s
.
index
(
1
));
EXPECT
(
s
.
index
({
0
,
0
,
1
,
0
})
==
s
.
index
(
8
));
EXPECT
(
s
.
index
({
0
,
1
,
0
,
0
})
==
s
.
index
(
8
*
8
));
EXPECT
(
s
.
index
({
1
,
0
,
0
,
0
})
==
s
.
index
(
8
*
8
*
32
));
EXPECT
(
s
.
index
(
0
)
==
0
);
EXPECT
(
s
.
index
(
1
)
==
1
);
EXPECT
(
s
.
index
(
8
)
==
8
);
EXPECT
(
s
.
index
(
8
*
8
)
==
8
*
8
);
EXPECT
(
s
.
index
(
8
*
8
*
32
)
==
8
*
8
*
32
);
EXPECT
(
s
.
index
(
s
.
elements
()
-
1
)
==
s
.
elements
()
-
1
);
}
void
test_shape4_transposed
()
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
32
,
100
,
8
,
8
},
{
64
,
2048
,
8
,
1
}};
EXPECT
(
s
.
transposed
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
lens
()[
0
]
==
32
);
EXPECT
(
s
.
lens
()[
1
]
==
100
);
EXPECT
(
s
.
lens
()[
2
]
==
8
);
EXPECT
(
s
.
lens
()[
3
]
==
8
);
EXPECT
(
s
.
strides
()[
0
]
==
64
);
EXPECT
(
s
.
strides
()[
1
]
==
2048
);
EXPECT
(
s
.
strides
()[
2
]
==
8
);
EXPECT
(
s
.
strides
()[
3
]
==
1
);
EXPECT
(
s
.
elements
()
==
100
*
32
*
8
*
8
);
EXPECT
(
s
.
bytes
()
==
100
*
32
*
8
*
8
*
sizeof
(
float
));
EXPECT
(
s
.
index
({
0
,
0
,
0
,
0
})
==
0
);
EXPECT
(
s
.
index
({
0
,
0
,
0
,
1
})
==
1
);
EXPECT
(
s
.
index
({
0
,
0
,
0
,
0
})
==
s
.
index
(
0
));
EXPECT
(
s
.
index
({
0
,
0
,
0
,
1
})
==
s
.
index
(
1
));
EXPECT
(
s
.
index
({
0
,
0
,
1
,
0
})
==
s
.
index
(
8
));
EXPECT
(
s
.
index
({
0
,
1
,
0
,
0
})
==
s
.
index
(
8
*
8
));
EXPECT
(
s
.
index
({
1
,
0
,
0
,
0
})
==
s
.
index
(
8
*
8
*
100
));
EXPECT
(
s
.
index
(
0
)
==
0
);
EXPECT
(
s
.
index
(
1
)
==
1
);
EXPECT
(
s
.
index
(
8
)
==
8
);
EXPECT
(
s
.
index
(
8
*
8
)
==
2048
);
EXPECT
(
s
.
index
(
8
*
8
*
100
)
==
64
);
EXPECT
(
s
.
index
(
s
.
elements
()
-
1
)
==
s
.
elements
()
-
1
);
}
void
test_shape4_nonpacked
()
{
std
::
vector
<
std
::
size_t
>
lens
=
{
100
,
32
,
8
,
8
};
...
...
@@ -134,11 +200,10 @@ void test_shape4_nonpacked()
EXPECT
(
s
.
index
(
1
)
==
1
);
EXPECT
(
s
.
index
({
0
,
0
,
0
,
0
})
==
0
);
EXPECT
(
s
.
index
({
0
,
0
,
0
,
1
})
==
s
.
index
(
1
));
// TODO: Fix these tests
// EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
// EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
// EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
// EXPECT(s.index(s.elements() - 1) == 469273);
EXPECT
(
s
.
index
({
0
,
0
,
1
,
0
})
==
s
.
index
(
8
));
EXPECT
(
s
.
index
({
0
,
1
,
0
,
0
})
==
s
.
index
(
8
*
8
));
EXPECT
(
s
.
index
({
1
,
0
,
0
,
0
})
==
s
.
index
(
8
*
8
*
32
));
EXPECT
(
s
.
index
(
s
.
elements
()
-
1
)
==
469273
);
}
int
main
()
...
...
@@ -151,5 +216,7 @@ int main()
test_shape_broadcasted
();
test_shape_default_copy
();
test_shape4
();
test_shape42
();
test_shape4_transposed
();
test_shape4_nonpacked
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment