Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
fdee4456
Commit
fdee4456
authored
Sep 12, 2017
by
Ed Ng
Committed by
GitHub
Sep 12, 2017
Browse files
Merge branch 'master' into get_locations
parents
94a39536
fc961107
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
116 additions
and
84 deletions
+116
-84
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
+57
-41
PyTorch/sparseconvnet/SCN/generic/SparseConvNet.h
PyTorch/sparseconvnet/SCN/generic/SparseConvNet.h
+1
-1
PyTorch/sparseconvnet/legacy/inputBatch.py
PyTorch/sparseconvnet/legacy/inputBatch.py
+2
-2
Torch/InputBatch.lua
Torch/InputBatch.lua
+7
-6
examples/hello-world.lua
examples/hello-world.lua
+36
-28
examples/hello-world.py
examples/hello-world.py
+13
-6
No files found.
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
View file @
fdee4456
...
@@ -23,65 +23,81 @@ extern "C" void scn_D_(batchAddSample)(void **m) {
...
@@ -23,65 +23,81 @@ extern "C" void scn_D_(batchAddSample)(void **m) {
_m
.
inputSGs
->
resize
(
_m
.
inputSGs
->
size
()
+
1
);
_m
.
inputSGs
->
resize
(
_m
.
inputSGs
->
size
()
+
1
);
_m
.
inputSG
=
&
_m
.
inputSGs
->
back
();
_m
.
inputSG
=
&
_m
.
inputSGs
->
back
();
}
}
extern
"C"
void
scn_D_
(
setInputSpatialLocation
)(
void
**
m
,
void
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
SparseGridMap
<
Dimension
>
&
mp
,
Point
<
Dimension
>
p
,
uInt
&
nActive
,
long
nPlanes
,
THFloatTensor
*
features
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
float
*
vec
,
bool
overwrite
)
{
THFloatTensor
*
vec
,
bool
overwrite
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
p
=
LongTensorToPoint
<
Dimension
>
(
location
);
auto
&
mp
=
_m
.
inputSG
->
mp
;
auto
&
nActive
=
*
_m
.
inputNActive
;
auto
iter
=
mp
.
find
(
p
);
auto
iter
=
mp
.
find
(
p
);
auto
nPlanes
=
vec
->
size
[
0
];
if
(
iter
==
mp
.
end
())
{
if
(
iter
==
mp
.
end
())
{
iter
=
mp
.
insert
(
std
::
make_pair
(
p
,
nActive
++
)).
first
;
iter
=
mp
.
insert
(
std
::
make_pair
(
p
,
nActive
++
)).
first
;
THFloatTensor_resize2d
(
features
,
nActive
,
nPlanes
);
THFloatTensor_resize2d
(
features
,
nActive
,
nPlanes
);
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
(
nActive
-
1
)
*
nPlanes
,
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
(
nActive
-
1
)
*
nPlanes
,
vec
,
THFloatTensor_data
(
vec
),
sizeof
(
float
)
*
nPlanes
);
sizeof
(
float
)
*
nPlanes
);
}
else
if
(
overwrite
)
{
}
else
if
(
overwrite
)
{
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
iter
->
second
*
nPlanes
,
std
::
memcpy
(
THFloatTensor_data
(
features
)
+
iter
->
second
*
nPlanes
,
vec
,
THFloatTensor_data
(
vec
),
sizeof
(
float
)
*
nPlanes
);
sizeof
(
float
)
*
nPlanes
);
}
}
}
}
extern
"C"
void
scn_D_
(
setInputSpatialLocation
s
)(
void
**
m
,
extern
"C"
void
scn_D_
(
setInputSpatialLocation
)(
void
**
m
,
THFloatTensor
*
features
,
THFloatTensor
*
features
,
THLongTensor
*
location
s
,
THLongTensor
*
location
,
THFloatTensor
*
vec
s
,
THFloatTensor
*
vec
,
bool
overwrite
)
{
bool
overwrite
)
{
assert
(
locations
->
size
[
0
]
==
vecs
->
size
[
0
]
&&
"Location and vec length must be identical!"
);
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
p
=
LongTensorToPoint
<
Dimension
>
(
location
);
auto
&
mp
=
_m
.
inputSG
->
mp
;
auto
&
mp
=
_m
.
inputSG
->
mp
;
auto
&
nActive
=
*
_m
.
inputNActive
;
auto
&
nActive
=
*
_m
.
inputNActive
;
auto
nSamples
=
locations
->
size
[
0
];
auto
nPlanes
=
vec
->
size
[
0
];
auto
isMpEmpty
=
mp
.
empty
();
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
mp
,
p
,
nActive
,
nPlanes
,
features
,
THFloatTensor_data
(
vec
),
overwrite
);
}
extern
"C"
void
scn_D_
(
setInputSpatialLocations
)(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
bool
overwrite
)
{
assert
(
locations
->
size
[
0
]
==
vecs
->
size
[
0
]
and
"Location.size(0) and vecs.size(0) must be equal!"
);
assert
((
locations
->
size
[
1
]
==
Dimension
or
locations
->
size
[
1
]
==
1
+
Dimension
)
and
"locations.size(0) must be either Dimension or Dimension+1"
);
if
(
isMpEmpty
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
nPlanes
=
vecs
->
size
[
1
];
THFloatTensor_resize2d
(
features
,
nSamples
,
nPlanes
);
Point
<
Dimension
>
p
;
std
::
memcpy
(
THFloatTensor_data
(
features
),
auto
&
nActive
=
*
_m
.
inputNActive
;
THFloatTensor_data
(
vecs
),
sizeof
(
float
)
*
nSamples
*
nPlanes
);
auto
nPlanes
=
vecs
->
size
[
1
];
auto
l
=
THLongTensor_data
(
locations
);
auto
v
=
THFloatTensor_data
(
vecs
);
mp
.
resize
(
nSamples
);
if
(
locations
->
size
[
1
]
==
Dimension
)
{
assert
(
_m
.
inputSG
);
// add points to current sample
auto
&
mp
=
_m
.
inputSG
->
mp
;
for
(
uInt
idx
=
0
;
idx
<
locations
->
size
[
0
];
++
idx
)
{
for
(
int
d
=
0
;
d
<
Dimension
;
++
d
)
p
[
d
]
=
*
l
++
;
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
mp
,
p
,
nActive
,
nPlanes
,
features
,
v
,
overwrite
);
v
+=
nPlanes
;
}
}
}
if
(
locations
->
size
[
1
]
==
Dimension
+
1
)
{
for
(
unsigned
int
i
=
0
;
i
<
nSamples
;
++
i
)
{
// add new samples to batch as necessary
THLongTensor
*
location
=
THLongTensor_newSelect
(
locations
,
0
,
i
);
auto
&
SGs
=
*
_m
.
inputSGs
;
THFloatTensor
*
vec
=
THFloatTensor_newSelect
(
vecs
,
0
,
i
);
for
(
uInt
idx
=
0
;
idx
<
locations
->
size
[
0
];
++
idx
)
{
for
(
int
d
=
0
;
d
<
Dimension
;
++
d
)
if
(
isMpEmpty
)
{
p
[
d
]
=
*
l
++
;
auto
p
=
LongTensorToPoint
<
Dimension
>
(
location
);
auto
batch
=
*
l
++
;
mp
.
insert
(
std
::
make_pair
(
p
,
nActive
++
));
if
(
batch
>=
SGs
.
size
())
{
}
else
{
SGs
.
resize
(
batch
+
1
);
scn_D_
(
setInputSpatialLocation
)(
m
,
features
,
location
,
vec
,
overwrite
);
}
auto
&
mp
=
SGs
[
batch
].
mp
;
scn_D_
(
addPointToSparseGridMapAndFeatures
)(
mp
,
p
,
nActive
,
nPlanes
,
features
,
v
,
overwrite
);
v
+=
nPlanes
;
}
}
THLongTensor_free
(
location
);
THFloatTensor_free
(
vec
);
}
}
}
}
extern
"C"
void
scn_D_
(
getSpatialLocations
)(
void
**
m
,
extern
"C"
void
scn_D_
(
getSpatialLocations
)(
void
**
m
,
...
...
PyTorch/sparseconvnet/SCN/generic/SparseConvNet.h
View file @
fdee4456
...
@@ -32,7 +32,7 @@
...
@@ -32,7 +32,7 @@
template
<
uInt
dimension
>
template
<
uInt
dimension
>
using
SparseGridMap
=
using
SparseGridMap
=
google
::
dense_hash_map
<
Point
<
dimension
>
,
i
nt
,
IntArrayHash
<
dimension
>
,
google
::
dense_hash_map
<
Point
<
dimension
>
,
uI
nt
,
IntArrayHash
<
dimension
>
,
std
::
equal_to
<
Point
<
dimension
>>>
;
std
::
equal_to
<
Point
<
dimension
>>>
;
template
<
uInt
dimension
>
class
SparseGrid
{
template
<
uInt
dimension
>
class
SparseGrid
{
...
...
PyTorch/sparseconvnet/legacy/inputBatch.py
View file @
fdee4456
...
@@ -33,8 +33,8 @@ class InputBatch(SparseConvNetTensor):
...
@@ -33,8 +33,8 @@ class InputBatch(SparseConvNetTensor):
self
.
metadata
.
ffi
,
self
.
features
,
location
,
vector
,
overwrite
)
self
.
metadata
.
ffi
,
self
.
features
,
location
,
vector
,
overwrite
)
def
setLocations
(
self
,
locations
,
vectors
,
overwrite
=
False
):
def
setLocations
(
self
,
locations
,
vectors
,
overwrite
=
False
):
assert
locations
.
min
()
>=
0
and
(
self
.
spatial_size
.
expand_as
(
locations
)
-
locations
).
min
()
>
0
l
=
locations
.
narrow
(
1
,
0
,
self
.
dimension
)
assert
l
.
min
()
>=
0
and
(
self
.
spatial_size
.
expand_as
(
l
)
-
l
).
min
()
>
0
dim_fn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
dim_fn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
self
.
metadata
.
ffi
,
self
.
features
,
locations
,
vectors
,
overwrite
)
self
.
metadata
.
ffi
,
self
.
features
,
locations
,
vectors
,
overwrite
)
...
...
Torch/InputBatch.lua
View file @
fdee4456
...
@@ -15,7 +15,7 @@ return function(sparseconvnet)
...
@@ -15,7 +15,7 @@ return function(sparseconvnet)
self
.
spatialSize
=
type
(
spatialSize
)
==
'number'
and
torch
.
LongTensor
(
self
.
spatialSize
=
type
(
spatialSize
)
==
'number'
and
torch
.
LongTensor
(
dimension
):
fill
(
spatialSize
)
or
spatialSize
dimension
):
fill
(
spatialSize
)
or
spatialSize
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialSize'
)(
self
.
metadata
.
ffi
,
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialSize'
)(
self
.
metadata
.
ffi
,
self
.
spatialSize
:
cdata
())
self
.
spatialSize
:
cdata
())
end
end
function
InputBatch
:
addSample
()
function
InputBatch
:
addSample
()
C
.
dimensionFn
(
self
.
dimension
,
'batchAddSample'
)(
self
.
metadata
.
ffi
)
C
.
dimensionFn
(
self
.
dimension
,
'batchAddSample'
)(
self
.
metadata
.
ffi
)
...
@@ -28,7 +28,7 @@ return function(sparseconvnet)
...
@@ -28,7 +28,7 @@ return function(sparseconvnet)
end
end
function
InputBatch
:
setLocation
(
location
,
vector
,
overwrite
)
function
InputBatch
:
setLocation
(
location
,
vector
,
overwrite
)
--[[location is a self.dimensional length set of coordinates:
--[[location is a self.dimensional length set of coordinates:
torch.LongStorage or a table]]
torch.LongStorage or a table]]
if
type
(
location
)
==
'table'
then
if
type
(
location
)
==
'table'
then
local
l
=
torch
.
LongStorage
(
self
.
dimension
)
local
l
=
torch
.
LongStorage
(
self
.
dimension
)
for
i
,
x
in
ipairs
(
location
)
do
for
i
,
x
in
ipairs
(
location
)
do
...
@@ -38,19 +38,20 @@ return function(sparseconvnet)
...
@@ -38,19 +38,20 @@ return function(sparseconvnet)
end
end
assert
(
location
:
min
()
>=
0
and
(
self
.
spatialSize
-
location
):
min
()
>
0
)
assert
(
location
:
min
()
>=
0
and
(
self
.
spatialSize
-
location
):
min
()
>
0
)
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialLocation'
)(
self
.
metadata
.
ffi
,
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialLocation'
)(
self
.
metadata
.
ffi
,
self
.
features
:
cdata
(),
location
:
cdata
(),
vector
:
cdata
(),
overwrite
)
self
.
features
:
cdata
(),
location
:
cdata
(),
vector
:
cdata
(),
overwrite
)
end
end
function
InputBatch
:
setLocations
(
locations
,
vectors
,
overwrite
)
function
InputBatch
:
setLocations
(
locations
,
vectors
,
overwrite
)
--[[locations is a n_locations x self.dimensional length set of coordinates:
--[[locations is a n_locations x self.dimensional length set of coordinates:
torch.LongStorage or a 2-D table]]
torch.LongStorage or a 2-D table]]
if
type
(
locations
)
==
'table'
then
if
type
(
locations
)
==
'table'
then
locations
=
torch
.
LongStorage
(
locations
)
locations
=
torch
.
LongStorage
(
locations
)
end
end
assert
(
locations
:
min
()
>=
0
and
(
self
.
spatialSize
:
view
(
1
,
self
.
dimension
):
expandAs
(
locations
)
-
locations
):
min
()
>
0
)
local
l
=
locations
:
narrow
(
2
,
1
,
self
.
dimension
)
assert
(
l
:
min
()
>=
0
and
(
self
.
spatialSize
:
view
(
1
,
self
.
dimension
):
expandAs
(
l
)
-
l
):
min
()
>
0
)
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
self
.
metadata
.
ffi
,
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
self
.
metadata
.
ffi
,
self
.
features
:
cdata
(),
locations
:
cdata
(),
vectors
:
cdata
(),
overwrite
)
self
.
features
:
cdata
(),
locations
:
cdata
(),
vectors
:
cdata
(),
overwrite
)
end
end
function
InputBatch
:
precomputeMetadata
(
stride
)
function
InputBatch
:
precomputeMetadata
(
stride
)
if
stride
==
2
then
if
stride
==
2
then
...
...
examples/hello-world.lua
View file @
fdee4456
...
@@ -10,26 +10,26 @@ tensorType = scn.cutorch and 'torch.CudaTensor' or 'torch.FloatTensor'
...
@@ -10,26 +10,26 @@ tensorType = scn.cutorch and 'torch.CudaTensor' or 'torch.FloatTensor'
model
=
scn
.
Sequential
()
model
=
scn
.
Sequential
()
:
add
(
scn
.
SparseVggNet
(
2
,
1
,{
--dimension 2, 1 input plane
:
add
(
scn
.
SparseVggNet
(
2
,
1
,{
--dimension 2, 1 input plane
{
'C'
,
8
},
-- 3x3 VSC convolution, 8 output planes, batchnorm, ReLU
{
'C'
,
8
},
-- 3x3 VSC convolution, 8 output planes, batchnorm, ReLU
{
'C'
,
8
},
-- and another
{
'C'
,
8
},
-- and another
{
'MP'
,
3
,
2
},
--max pooling, size 3, stride 2
{
'MP'
,
3
,
2
},
--max pooling, size 3, stride 2
{
'C'
,
16
},
-- etc
{
'C'
,
16
},
-- etc
{
'C'
,
16
},
{
'C'
,
16
},
{
'MP'
,
3
,
2
},
{
'MP'
,
3
,
2
},
{
'C'
,
24
},
{
'C'
,
24
},
{
'C'
,
24
},
{
'C'
,
24
},
{
'MP'
,
3
,
2
}}))
{
'MP'
,
3
,
2
}}))
:
add
(
scn
.
Convolution
(
2
,
24
,
32
,
3
,
1
,
false
))
--an SC convolution on top
:
add
(
scn
.
Convolution
(
2
,
24
,
32
,
3
,
1
,
false
))
--an SC convolution on top
:
add
(
scn
.
BatchNormReLU
(
32
))
:
add
(
scn
.
BatchNormReLU
(
32
))
:
add
(
scn
.
SparseToDense
(
2
))
:
add
(
scn
.
SparseToDense
(
2
))
:
type
(
tensorType
)
:
type
(
tensorType
)
--[[
--[[
To use the network we must create an scn.InputBatch with right dimensionality.
To use the network we must create an scn.InputBatch with right dimensionality.
If we want the output to have spatial size 10x10, we can find the appropriate
If we want the output to have spatial size 10x10, we can find the appropriate
input size, give that we uses three layers of MP3/2 max-pooling, and finish
input size, give that we uses three layers of MP3/2 max-pooling, and finish
with a SC convoluton
with a SC convoluton
]]
]]
inputSpatialSize
=
model
:
suggestInputSize
(
torch
.
LongTensor
{
10
,
10
})
--103x103
inputSpatialSize
=
model
:
suggestInputSize
(
torch
.
LongTensor
{
10
,
10
})
--103x103
...
@@ -43,12 +43,21 @@ msg={
...
@@ -43,12 +43,21 @@ msg={
" O O O O O O O O O O O O O O O O O O "
,
" O O O O O O O O O O O O O O O O O O "
,
" O O OOO OOO OOO OO O O OO O O OOO OOO "
,
" O O OOO OOO OOO OO O O OO O O OOO OOO "
,
}
}
input
:
addSample
()
input
:
addSample
()
for
y
,
line
in
ipairs
(
msg
)
do
for
x
=
1
,
string.len
(
line
)
do
if
string.sub
(
line
,
x
,
x
)
==
'O'
then
local
location
=
torch
.
LongTensor
{
x
,
y
}
local
featureVector
=
torch
.
FloatTensor
{
1
}
input
:
setLocation
(
location
,
featureVector
,
0
)
end
end
end
--We can also use setLocations
input
:
addSample
()
local
locations
=
{}
local
locations
=
{}
local
featureVectors
=
{}
local
featureVectors
=
{}
for
y
,
line
in
ipairs
(
msg
)
do
for
y
,
line
in
ipairs
(
msg
)
do
for
x
=
1
,
string.len
(
line
)
do
for
x
=
1
,
string.len
(
line
)
do
if
string.sub
(
line
,
x
,
x
)
==
'O'
then
if
string.sub
(
line
,
x
,
x
)
==
'O'
then
...
@@ -57,19 +66,18 @@ for y,line in ipairs(msg) do
...
@@ -57,19 +66,18 @@ for y,line in ipairs(msg) do
end
end
end
end
end
end
input
:
setLocations
(
input
:
setLocations
(
torch
.
LongTensor
(
locations
),
torch
.
LongTensor
(
locations
),
torch
.
FloatTensor
(
featureVectors
),
torch
.
FloatTensor
(
featureVectors
),
0
)
0
)
--[[
--[[
Optional: allow metadata preprocessing to be done in batch preparation threads
Optional: allow metadata preprocessing to be done in batch preparation threads
to improve GPU utilization.
to improve GPU utilization.
Parameter:
Parameter:
3 if using MP3/2 or size-3 stride-2 convolutions for downsizeing,
3 if using MP3/2 or size-3 stride-2 convolutions for downsizeing,
2 if using MP2
2 if using MP2
]]
]]
input
:
precomputeMetadata
(
3
)
input
:
precomputeMetadata
(
3
)
...
@@ -78,7 +86,7 @@ input:type(tensorType)
...
@@ -78,7 +86,7 @@ input:type(tensorType)
output
=
model
:
forward
(
input
)
output
=
model
:
forward
(
input
)
--[[
--[[
Output is
1
x32x10x10: our minibatch has
1
sample, the network has 32 output
Output is
2
x32x10x10: our minibatch has
2
sample
s
, the network has 32 output
feature planes, and 10x10 is the spatial size of the output.
feature planes, and 10x10 is the spatial size of the output.
]]
]]
print
(
output
:
size
(),
output
:
type
())
print
(
output
:
size
(),
output
:
type
())
examples/hello-world.py
View file @
fdee4456
...
@@ -13,9 +13,9 @@ dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatT
...
@@ -13,9 +13,9 @@ dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatT
model
=
scn
.
Sequential
().
add
(
model
=
scn
.
Sequential
().
add
(
scn
.
SparseVggNet
(
2
,
1
,
scn
.
SparseVggNet
(
2
,
1
,
[[
'C'
,
8
],
[
'C'
,
8
],
[
'MP'
,
3
,
2
],
[[
'C'
,
8
],
[
'C'
,
8
],
[
'MP'
,
3
,
2
],
[
'C'
,
16
],
[
'C'
,
16
],
[
'MP'
,
3
,
2
],
[
'C'
,
16
],
[
'C'
,
16
],
[
'MP'
,
3
,
2
],
[
'C'
,
24
],
[
'C'
,
24
],
[
'MP'
,
3
,
2
]])
[
'C'
,
24
],
[
'C'
,
24
],
[
'MP'
,
3
,
2
]])
).
add
(
).
add
(
scn
.
ValidConvolution
(
2
,
24
,
32
,
3
,
False
)
scn
.
ValidConvolution
(
2
,
24
,
32
,
3
,
False
)
).
add
(
).
add
(
...
@@ -34,20 +34,27 @@ msg = [
...
@@ -34,20 +34,27 @@ msg = [
" XXXXX XX X X X X X X X X X XXX X X X "
,
" XXXXX XX X X X X X X X X X XXX X X X "
,
" X X X X X X X X X X X X X X X X X X "
,
" X X X X X X X X X X X X X X X X X X "
,
" X X XXX XXX XXX XX X X XX X X XXX XXX "
]
" X X XXX XXX XXX XX X X XX X X XXX XXX "
]
#Add a sample using setLocation
input
.
addSample
()
input
.
addSample
()
for
y
,
line
in
enumerate
(
msg
):
for
x
,
c
in
enumerate
(
line
):
if
c
==
'X'
:
location
=
torch
.
LongTensor
([
x
,
y
])
featureVector
=
torch
.
FloatTensor
([
1
])
input
.
setLocation
(
location
,
featureVector
,
0
)
#Add a sample using setLocations
input
.
addSample
()
locations
=
[]
locations
=
[]
features
=
[]
features
=
[]
for
y
,
line
in
enumerate
(
msg
):
for
y
,
line
in
enumerate
(
msg
):
for
x
,
c
in
enumerate
(
line
):
for
x
,
c
in
enumerate
(
line
):
if
c
==
'X'
:
if
c
==
'X'
:
locations
.
append
([
x
,
y
])
locations
.
append
([
x
,
y
])
features
.
append
([
1
])
features
.
append
([
1
])
locations
=
torch
.
LongTensor
(
locations
)
locations
=
torch
.
LongTensor
(
locations
)
features
=
torch
.
FloatTensor
(
features
)
features
=
torch
.
FloatTensor
(
features
)
input
.
setLocations
(
locations
,
features
,
0
)
input
.
setLocations
(
locations
,
features
,
0
)
# Optional: allow metadata preprocessing to be done in batch preparation threads
# Optional: allow metadata preprocessing to be done in batch preparation threads
...
@@ -62,6 +69,6 @@ model.evaluate()
...
@@ -62,6 +69,6 @@ model.evaluate()
input
.
type
(
dtype
)
input
.
type
(
dtype
)
output
=
model
.
forward
(
input
)
output
=
model
.
forward
(
input
)
# Output is
1
x32x10x10: our minibatch has
1
sample, the network has 32 output
# Output is
2
x32x10x10: our minibatch has
2
sample
s
, the network has 32 output
# feature planes, and 10x10 is the spatial size of the output.
# feature planes, and 10x10 is the spatial size of the output.
print
(
output
.
size
(),
output
.
type
())
print
(
output
.
size
(),
output
.
type
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment