Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
d796a754
Commit
d796a754
authored
Aug 24, 2017
by
Benjamin Thomas Graham
Browse files
Allow setLocations to include additional sampleIdx column in locations
parent
95b46a86
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
107 additions
and
77 deletions
+107
-77
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
+57
-41
PyTorch/sparseconvnet/SCN/generic/SparseConvNet.h
PyTorch/sparseconvnet/SCN/generic/SparseConvNet.h
+1
-1
PyTorch/sparseconvnet/legacy/inputBatch.py
PyTorch/sparseconvnet/legacy/inputBatch.py
+0
-1
examples/hello-world.lua
examples/hello-world.lua
+36
-28
examples/hello-world.py
examples/hello-world.py
+13
-6
No files found.
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
View file @
d796a754
...
...
@@ -23,6 +23,24 @@ extern "C" void scn_D_(batchAddSample)(void **m) {
_m
.
inputSGs
->
resize
(
_m
.
inputSGs
->
size
()
+
1
);
_m
.
inputSG
=
&
_m
.
inputSGs
->
back
();
}
void
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
SparseGridMap
<
Dimension
>
&
mp
,
Point
<
Dimension
>
p
,
uInt
&
nActive
,
long
nPlanes
,
THFloatTensor
*
features
,
float
*
vec
,
bool
overwrite
)
{
auto
iter
=
mp
.
find
(
p
);
if
(
iter
==
mp
.
end
())
{
iter
=
mp
.
insert
(
std
::
make_pair
(
p
,
nActive
++
)).
first
;
THFloatTensor_resize2d
(
features
,
nActive
,
nPlanes
);
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
(
nActive
-
1
)
*
nPlanes
,
vec
,
sizeof
(
float
)
*
nPlanes
);
}
else
if
(
overwrite
)
{
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
iter
->
second
*
nPlanes
,
vec
,
sizeof
(
float
)
*
nPlanes
);
}
}
extern
"C"
void
scn_D_
(
setInputSpatialLocation
)(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
...
...
@@ -32,57 +50,55 @@ extern "C" void scn_D_(setInputSpatialLocation)(void **m,
auto
p
=
LongTensorToPoint
<
Dimension
>
(
location
);
auto
&
mp
=
_m
.
inputSG
->
mp
;
auto
&
nActive
=
*
_m
.
inputNActive
;
auto
iter
=
mp
.
find
(
p
);
auto
nPlanes
=
vec
->
size
[
0
];
if
(
iter
==
mp
.
end
())
{
iter
=
mp
.
insert
(
std
::
make_pair
(
p
,
nActive
++
)).
first
;
THFloatTensor_resize2d
(
features
,
nActive
,
nPlanes
);
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
(
nActive
-
1
)
*
nPlanes
,
THFloatTensor_data
(
vec
),
sizeof
(
float
)
*
nPlanes
);
}
else
if
(
overwrite
)
{
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
iter
->
second
*
nPlanes
,
THFloatTensor_data
(
vec
),
sizeof
(
float
)
*
nPlanes
);
}
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
mp
,
p
,
nActive
,
nPlanes
,
features
,
THFloatTensor_data
(
vec
),
overwrite
);
}
extern
"C"
void
scn_D_
(
setInputSpatialLocations
)(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
bool
overwrite
)
{
assert
(
locations
->
size
[
0
]
==
vecs
->
size
[
0
]
&&
"Location and vec length must be identical!"
);
assert
(
locations
->
size
[
0
]
==
vecs
->
size
[
0
]
and
"Location.size(0) and vecs.size(0) must be equal!"
);
assert
((
locations
->
size
[
1
]
==
Dimension
or
locations
->
size
[
1
]
==
1
+
Dimension
)
and
"locations.size(0) must be either Dimension or Dimension+1"
);
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
&
mp
=
_m
.
inputSG
->
mp
;
auto
&
nActive
=
*
_m
.
inputNActive
;
auto
nSamples
=
locations
->
size
[
0
];
auto
isMpEmpty
=
mp
.
empty
();
if
(
isMpEmpty
)
{
Point
<
Dimension
>
p
;
auto
&
nActive
=
*
_m
.
inputNActive
;
auto
nPlanes
=
vecs
->
size
[
1
];
auto
l
=
THLongTensor_data
(
locations
);
auto
v
=
THFloatTensor_data
(
vecs
);
THFloatTensor_resize2d
(
features
,
nSamples
,
nPlanes
);
std
::
memcpy
(
THFloatTensor_data
(
features
),
THFloatTensor_data
(
vecs
),
sizeof
(
float
)
*
nSamples
*
nPlanes
);
mp
.
resize
(
nSamples
);
if
(
locations
->
size
[
1
]
==
Dimension
)
{
assert
(
_m
.
inputSG
);
// add points to current sample
auto
&
mp
=
_m
.
inputSG
->
mp
;
for
(
uInt
idx
=
0
;
idx
<
locations
->
size
[
0
];
++
idx
)
{
for
(
int
d
=
0
;
d
<
Dimension
;
++
d
)
p
[
d
]
=
*
l
++
;
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
mp
,
p
,
nActive
,
nPlanes
,
features
,
v
,
overwrite
);
v
+=
nPlanes
;
}
for
(
unsigned
int
i
=
0
;
i
<
nSamples
;
++
i
)
{
THLongTensor
*
location
=
THLongTensor_newSelect
(
locations
,
0
,
i
);
THFloatTensor
*
vec
=
THFloatTensor_newSelect
(
vecs
,
0
,
i
);
if
(
isMpEmpty
)
{
auto
p
=
LongTensorToPoint
<
Dimension
>
(
location
);
mp
.
insert
(
std
::
make_pair
(
p
,
nActive
++
));
}
else
{
scn_D_
(
setInputSpatialLocation
)(
m
,
features
,
location
,
vec
,
overwrite
);
}
THLongTensor_free
(
location
);
THFloatTensor_free
(
vec
);
if
(
locations
->
size
[
1
]
==
Dimension
+
1
)
{
// add new samples to batch as necessary
auto
SGs
=
*
_m
.
inputSGs
;
for
(
uInt
idx
=
0
;
idx
<
locations
->
size
[
0
];
++
idx
)
{
for
(
int
d
=
0
;
d
<
Dimension
;
++
d
)
p
[
d
]
=
*
l
++
;
auto
batch
=
*
l
++
;
if
(
batch
>=
SGs
.
size
())
{
SGs
.
resize
(
batch
+
1
);
}
auto
&
mp
=
SGs
[
batch
].
mp
;
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
mp
,
p
,
nActive
,
nPlanes
,
features
,
v
,
overwrite
);
v
+=
nPlanes
;
}
}
}
...
...
PyTorch/sparseconvnet/SCN/generic/SparseConvNet.h
View file @
d796a754
...
...
@@ -32,7 +32,7 @@
template
<
uInt
dimension
>
using
SparseGridMap
=
google
::
dense_hash_map
<
Point
<
dimension
>
,
i
nt
,
IntArrayHash
<
dimension
>
,
google
::
dense_hash_map
<
Point
<
dimension
>
,
uI
nt
,
IntArrayHash
<
dimension
>
,
std
::
equal_to
<
Point
<
dimension
>>>
;
template
<
uInt
dimension
>
class
SparseGrid
{
...
...
PyTorch/sparseconvnet/legacy/inputBatch.py
View file @
d796a754
...
...
@@ -34,7 +34,6 @@ class InputBatch(SparseConvNetTensor):
def
setLocations
(
self
,
locations
,
vectors
,
overwrite
=
False
):
assert
locations
.
min
()
>=
0
and
(
self
.
spatial_size
.
expand_as
(
locations
)
-
locations
).
min
()
>
0
dim_fn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
self
.
metadata
.
ffi
,
self
.
features
,
locations
,
vectors
,
overwrite
)
...
...
examples/hello-world.lua
View file @
d796a754
...
...
@@ -10,7 +10,7 @@ tensorType = scn.cutorch and 'torch.CudaTensor' or 'torch.FloatTensor'
model
=
scn
.
Sequential
()
:
add
(
scn
.
SparseVggNet
(
2
,
1
,{
--dimension 2, 1 input plane
:
add
(
scn
.
SparseVggNet
(
2
,
1
,{
--dimension 2, 1 input plane
{
'C'
,
8
},
-- 3x3 VSC convolution, 8 output planes, batchnorm, ReLU
{
'C'
,
8
},
-- and another
{
'MP'
,
3
,
2
},
--max pooling, size 3, stride 2
...
...
@@ -20,16 +20,16 @@ model = scn.Sequential()
{
'C'
,
24
},
{
'C'
,
24
},
{
'MP'
,
3
,
2
}}))
:
add
(
scn
.
Convolution
(
2
,
24
,
32
,
3
,
1
,
false
))
--an SC convolution on top
:
add
(
scn
.
BatchNormReLU
(
32
))
:
add
(
scn
.
SparseToDense
(
2
))
:
type
(
tensorType
)
:
add
(
scn
.
Convolution
(
2
,
24
,
32
,
3
,
1
,
false
))
--an SC convolution on top
:
add
(
scn
.
BatchNormReLU
(
32
))
:
add
(
scn
.
SparseToDense
(
2
))
:
type
(
tensorType
)
--[[
To use the network we must create an scn.InputBatch with right dimensionality.
If we want the output to have spatial size 10x10, we can find the appropriate
input size, give that we uses three layers of MP3/2 max-pooling, and finish
with a SC convoluton
To use the network we must create an scn.InputBatch with right dimensionality.
If we want the output to have spatial size 10x10, we can find the appropriate
input size, give that we uses three layers of MP3/2 max-pooling, and finish
with a SC convoluton
]]
inputSpatialSize
=
model
:
suggestInputSize
(
torch
.
LongTensor
{
10
,
10
})
--103x103
...
...
@@ -43,12 +43,21 @@ msg={
" O O O O O O O O O O O O O O O O O O "
,
" O O OOO OOO OOO OO O O OO O O OOO OOO "
,
}
input
:
addSample
()
for
y
,
line
in
ipairs
(
msg
)
do
for
x
=
1
,
string.len
(
line
)
do
if
string.sub
(
line
,
x
,
x
)
==
'O'
then
local
location
=
torch
.
LongTensor
{
x
,
y
}
local
featureVector
=
torch
.
FloatTensor
{
1
}
input
:
setLocation
(
location
,
featureVector
,
0
)
end
end
end
--We can also use setLocations
input
:
addSample
()
local
locations
=
{}
local
featureVectors
=
{}
for
y
,
line
in
ipairs
(
msg
)
do
for
x
=
1
,
string.len
(
line
)
do
if
string.sub
(
line
,
x
,
x
)
==
'O'
then
...
...
@@ -57,19 +66,18 @@ for y,line in ipairs(msg) do
end
end
end
input
:
setLocations
(
torch
.
LongTensor
(
locations
),
torch
.
FloatTensor
(
featureVectors
),
0
)
--[[
Optional: allow metadata preprocessing to be done in batch preparation threads
to improve GPU utilization.
Optional: allow metadata preprocessing to be done in batch preparation threads
to improve GPU utilization.
Parameter:
3 if using MP3/2 or size-3 stride-2 convolutions for downsizeing,
2 if using MP2
Parameter:
3 if using MP3/2 or size-3 stride-2 convolutions for downsizeing,
2 if using MP2
]]
input
:
precomputeMetadata
(
3
)
...
...
@@ -78,7 +86,7 @@ input:type(tensorType)
output
=
model
:
forward
(
input
)
--[[
Output is
1
x32x10x10: our minibatch has
1
sample, the network has 32 output
feature planes, and 10x10 is the spatial size of the output.
Output is
2
x32x10x10: our minibatch has
2
sample
s
, the network has 32 output
feature planes, and 10x10 is the spatial size of the output.
]]
print
(
output
:
size
(),
output
:
type
())
examples/hello-world.py
View file @
d796a754
...
...
@@ -34,20 +34,27 @@ msg = [
" XXXXX XX X X X X X X X X X XXX X X X "
,
" X X X X X X X X X X X X X X X X X X "
,
" X X XXX XXX XXX XX X X XX X X XXX XXX "
]
#Add a sample using setLocation
input
.
addSample
()
for
y
,
line
in
enumerate
(
msg
):
for
x
,
c
in
enumerate
(
line
):
if
c
==
'X'
:
location
=
torch
.
LongTensor
([
x
,
y
])
featureVector
=
torch
.
FloatTensor
([
1
])
input
.
setLocation
(
location
,
featureVector
,
0
)
#Add a sample using setLocations
input
.
addSample
()
locations
=
[]
features
=
[]
for
y
,
line
in
enumerate
(
msg
):
for
x
,
c
in
enumerate
(
line
):
if
c
==
'X'
:
locations
.
append
([
x
,
y
])
features
.
append
([
1
])
locations
=
torch
.
LongTensor
(
locations
)
features
=
torch
.
FloatTensor
(
features
)
input
.
setLocations
(
locations
,
features
,
0
)
# Optional: allow metadata preprocessing to be done in batch preparation threads
...
...
@@ -62,6 +69,6 @@ model.evaluate()
input
.
type
(
dtype
)
output
=
model
.
forward
(
input
)
# Output is
1
x32x10x10: our minibatch has
1
sample, the network has 32 output
# Output is
2
x32x10x10: our minibatch has
2
sample
s
, the network has 32 output
# feature planes, and 10x10 is the spatial size of the output.
print
(
output
.
size
(),
output
.
type
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment