Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
db6454cd
Commit
db6454cd
authored
Aug 04, 2017
by
Ed Ng
Browse files
AddLocations API
Introduce a setInputLocations (batch) API for sparse map construction. 0.1.0 => 0.1.1
parent
fcc28e95
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
128 additions
and
12 deletions
+128
-12
.gitignore
.gitignore
+5
-1
PyTorch/setup.py
PyTorch/setup.py
+1
-1
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
+42
-0
PyTorch/sparseconvnet/SCN/header_cpu.h
PyTorch/sparseconvnet/SCN/header_cpu.h
+31
-1
PyTorch/sparseconvnet/legacy/inputBatch.py
PyTorch/sparseconvnet/legacy/inputBatch.py
+10
-1
README.md
README.md
+1
-1
Torch/C.lua
Torch/C.lua
+3
-0
Torch/InputBatch.lua
Torch/InputBatch.lua
+12
-0
Torch/sparseconvnet-0.1-1.rockspec
Torch/sparseconvnet-0.1-1.rockspec
+1
-1
examples/hello-world.lua
examples/hello-world.lua
+11
-3
examples/hello-world.py
examples/hello-world.py
+11
-3
No files found.
.gitignore
View file @
db6454cd
...
@@ -6,4 +6,8 @@ t7/
...
@@ -6,4 +6,8 @@ t7/
*.so
*.so
build
build
__pycache__
__pycache__
pickle
pickle
\ No newline at end of file
*.pyc
PyTorch/sparseconvnet.egg-info/
PyTorch/sparseconvnet/SCN/__init__.py
PyTorch/setup.py
View file @
db6454cd
...
@@ -63,7 +63,7 @@ ffi.build()
...
@@ -63,7 +63,7 @@ ffi.build()
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
setup
(
setup
(
name
=
'sparseconvnet'
,
name
=
'sparseconvnet'
,
version
=
'0.1'
,
version
=
'0.1
.1
'
,
description
=
'Submanifold (Spatially) Sparse Convolutional Networks https://arxiv.org/abs/1706.01307'
,
description
=
'Submanifold (Spatially) Sparse Convolutional Networks https://arxiv.org/abs/1706.01307'
,
author
=
'Facebook AI Research'
,
author
=
'Facebook AI Research'
,
author_email
=
'benjamingraham@fb.com'
,
author_email
=
'benjamingraham@fb.com'
,
...
...
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
View file @
db6454cd
...
@@ -44,6 +44,48 @@ extern "C" void scn_D_(setInputSpatialLocation)(void **m,
...
@@ -44,6 +44,48 @@ extern "C" void scn_D_(setInputSpatialLocation)(void **m,
THFloatTensor_data
(
vec
),
sizeof
(
float
)
*
nPlanes
);
THFloatTensor_data
(
vec
),
sizeof
(
float
)
*
nPlanes
);
}
}
}
}
extern
"C"
void
scn_D_
(
setInputSpatialLocations
)(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
bool
overwrite
)
{
assert
(
locations
->
size
[
0
]
==
vecs
->
size
[
0
]
&&
"Location and vec length must be identical!"
);
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
&
mp
=
_m
.
inputSG
->
mp
;
auto
&
nActive
=
*
_m
.
inputNActive
;
auto
nSamples
=
locations
->
size
[
0
];
auto
isMpEmpty
=
mp
.
empty
();
if
(
isMpEmpty
)
{
auto
nPlanes
=
vecs
->
size
[
1
];
THFloatTensor_resize2d
(
features
,
nSamples
,
nPlanes
);
std
::
memcpy
(
THFloatTensor_data
(
features
),
THFloatTensor_data
(
vecs
),
sizeof
(
float
)
*
nSamples
*
nPlanes
);
mp
.
resize
(
nSamples
);
}
for
(
unsigned
int
i
=
0
;
i
<
nSamples
;
++
i
)
{
THLongTensor
*
location
=
THLongTensor_newSelect
(
locations
,
0
,
i
);
THFloatTensor
*
vec
=
THFloatTensor_newSelect
(
vecs
,
0
,
i
);
if
(
isMpEmpty
)
{
auto
p
=
LongTensorToPoint
<
Dimension
>
(
location
);
mp
.
insert
(
std
::
make_pair
(
p
,
nActive
++
));
}
else
{
scn_D_
(
setInputSpatialLocation
)(
m
,
features
,
location
,
vec
,
overwrite
);
}
THLongTensor_free
(
location
);
THFloatTensor_free
(
vec
);
}
}
extern
"C"
void
extern
"C"
void
scn_D_
(
createMetadataForDenseToSparse
)(
void
**
m
,
THLongTensor
*
spatialSize_
,
scn_D_
(
createMetadataForDenseToSparse
)(
void
**
m
,
THLongTensor
*
spatialSize_
,
THLongTensor
*
pad_
,
THLongTensor
*
pad_
,
...
...
PyTorch/sparseconvnet/SCN/header_cpu.h
View file @
db6454cd
...
@@ -25,6 +25,9 @@ void scn_1_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -25,6 +25,9 @@ void scn_1_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_1_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_1_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_1_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
double
scn_2_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_2_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
THLongTensor
*
offset_
,
THLongTensor
*
offset_
,
...
@@ -39,7 +42,10 @@ void scn_2_generateRuleBooks3s2(void **m);
...
@@ -39,7 +42,10 @@ void scn_2_generateRuleBooks3s2(void **m);
void
scn_2_generateRuleBooks2s2
(
void
**
m
);
void
scn_2_generateRuleBooks2s2
(
void
**
m
);
void
scn_2_setInputSpatialSize
(
void
**
m
,
THLongTensor
*
spatialSize
);
void
scn_2_setInputSpatialSize
(
void
**
m
,
THLongTensor
*
spatialSize
);
void
scn_2_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_2_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
void
scn_2_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
_Bool
overwrite
);
double
scn_3_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_3_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
...
@@ -57,6 +63,9 @@ void scn_3_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -57,6 +63,9 @@ void scn_3_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_3_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_3_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_3_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
double
scn_4_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_4_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
THLongTensor
*
offset_
,
THLongTensor
*
offset_
,
...
@@ -73,6 +82,9 @@ void scn_4_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -73,6 +82,9 @@ void scn_4_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_4_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_4_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_4_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
double
scn_5_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_5_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
THLongTensor
*
offset_
,
THLongTensor
*
offset_
,
...
@@ -89,6 +101,9 @@ void scn_5_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -89,6 +101,9 @@ void scn_5_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_5_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_5_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_5_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
double
scn_6_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_6_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
THLongTensor
*
offset_
,
THLongTensor
*
offset_
,
...
@@ -105,6 +120,9 @@ void scn_6_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -105,6 +120,9 @@ void scn_6_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_6_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_6_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_6_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
double
scn_7_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_7_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
THLongTensor
*
offset_
,
THLongTensor
*
offset_
,
...
@@ -121,6 +139,9 @@ void scn_7_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -121,6 +139,9 @@ void scn_7_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_7_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_7_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_7_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
double
scn_8_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_8_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
THLongTensor
*
offset_
,
THLongTensor
*
offset_
,
...
@@ -136,6 +157,9 @@ void scn_8_generateRuleBooks2s2(void **m);
...
@@ -136,6 +157,9 @@ void scn_8_generateRuleBooks2s2(void **m);
void
scn_8_setInputSpatialSize
(
void
**
m
,
THLongTensor
*
spatialSize
);
void
scn_8_setInputSpatialSize
(
void
**
m
,
THLongTensor
*
spatialSize
);
void
scn_8_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_8_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
void
scn_8_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
_Bool
overwrite
);
double
scn_9_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_9_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
...
@@ -153,6 +177,9 @@ void scn_9_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -153,6 +177,9 @@ void scn_9_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_9_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_9_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_9_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
double
scn_10_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
double
scn_10_addSampleFromThresholdedTensor
(
void
**
m
,
THFloatTensor
*
features_
,
THFloatTensor
*
tensor_
,
THFloatTensor
*
tensor_
,
THLongTensor
*
offset_
,
THLongTensor
*
offset_
,
...
@@ -169,6 +196,9 @@ void scn_10_setInputSpatialSize(void **m, THLongTensor *spatialSize);
...
@@ -169,6 +196,9 @@ void scn_10_setInputSpatialSize(void **m, THLongTensor *spatialSize);
void
scn_10_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
void
scn_10_setInputSpatialLocation
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
THLongTensor
*
location
,
THFloatTensor
*
vec
,
_Bool
overwrite
);
_Bool
overwrite
);
void
scn_10_setInputSpatialLocations
(
void
**
m
,
THFloatTensor
*
features
,
THLongTensor
*
locations
,
THFloatTensor
*
vecs
,
_Bool
overwrite
);
void
scn_cpu_float_AffineReluTrivialConvolution_updateOutput
(
void
scn_cpu_float_AffineReluTrivialConvolution_updateOutput
(
THFloatTensor
*
input_features
,
THFloatTensor
*
output_features
,
THFloatTensor
*
input_features
,
THFloatTensor
*
output_features
,
THFloatTensor
*
affineWeight
,
THFloatTensor
*
affineBias
,
THFloatTensor
*
affineWeight
,
THFloatTensor
*
affineBias
,
...
...
PyTorch/sparseconvnet/legacy/inputBatch.py
View file @
db6454cd
...
@@ -9,7 +9,6 @@ from .metadata import Metadata
...
@@ -9,7 +9,6 @@ from .metadata import Metadata
from
..utils
import
toLongTensor
,
dim_fn
from
..utils
import
toLongTensor
,
dim_fn
from
.sparseConvNetTensor
import
SparseConvNetTensor
from
.sparseConvNetTensor
import
SparseConvNetTensor
class
InputBatch
(
SparseConvNetTensor
):
class
InputBatch
(
SparseConvNetTensor
):
def
__init__
(
self
,
dimension
,
spatial_size
):
def
__init__
(
self
,
dimension
,
spatial_size
):
self
.
dimension
=
dimension
self
.
dimension
=
dimension
...
@@ -33,6 +32,16 @@ class InputBatch(SparseConvNetTensor):
...
@@ -33,6 +32,16 @@ class InputBatch(SparseConvNetTensor):
dim_fn
(
self
.
dimension
,
'setInputSpatialLocation'
)(
dim_fn
(
self
.
dimension
,
'setInputSpatialLocation'
)(
self
.
metadata
.
ffi
,
self
.
features
,
location
,
vector
,
overwrite
)
self
.
metadata
.
ffi
,
self
.
features
,
location
,
vector
,
overwrite
)
def
setLocations
(
self
,
locations
,
vectors
,
overwrite
=
False
):
assert
locations
.
min
()
>=
0
and
(
self
.
spatial_size
.
expand_as
(
locations
)
-
locations
).
min
()
>
0
dim_fn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
self
.
metadata
.
ffi
,
self
.
features
,
locations
,
vectors
,
overwrite
)
def
setLocations_
(
self
,
locations
,
vector
,
overwrite
=
False
):
dim_fn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
self
.
metadata
.
ffi
,
self
.
features
,
locations
,
vectors
,
overwrite
)
def
addSampleFromTensor
(
self
,
tensor
,
offset
,
threshold
=
0
):
def
addSampleFromTensor
(
self
,
tensor
,
offset
,
threshold
=
0
):
self
.
nActive
=
dim_fn
(
self
.
nActive
=
dim_fn
(
self
.
dimension
,
self
.
dimension
,
...
...
README.md
View file @
db6454cd
...
@@ -203,7 +203,7 @@ git clone git@github.com:facebookresearch/SparseConvNet.git
...
@@ -203,7 +203,7 @@ git clone git@github.com:facebookresearch/SparseConvNet.git
then
then
cd SparseConvNet/Torch/
cd SparseConvNet/Torch/
luarocks make sparseconvnet-0.1-
0
.rockspec
luarocks make sparseconvnet-0.1-
1
.rockspec
and/or
and/or
...
...
Torch/C.lua
View file @
db6454cd
...
@@ -49,6 +49,9 @@ return function (sparseconvnet)
...
@@ -49,6 +49,9 @@ return function (sparseconvnet)
void scn_DIMENSION_setInputSpatialLocation(void **m,
void scn_DIMENSION_setInputSpatialLocation(void **m,
THFloatTensor *features, THLongTensor *location, THFloatTensor *vec,
THFloatTensor *features, THLongTensor *location, THFloatTensor *vec,
bool overwrite);
bool overwrite);
void scn_DIMENSION_setInputSpatialLocations(void **m,
THFloatTensor *features, THLongTensor *locations, THFloatTensor *vecs,
bool overwrite);
]]
]]
for
DIMENSION
=
1
,
10
do
for
DIMENSION
=
1
,
10
do
...
...
Torch/InputBatch.lua
View file @
db6454cd
...
@@ -40,6 +40,18 @@ return function(sparseconvnet)
...
@@ -40,6 +40,18 @@ return function(sparseconvnet)
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialLocation'
)(
self
.
metadata
.
ffi
,
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialLocation'
)(
self
.
metadata
.
ffi
,
self
.
features
:
cdata
(),
location
:
cdata
(),
vector
:
cdata
(),
overwrite
)
self
.
features
:
cdata
(),
location
:
cdata
(),
vector
:
cdata
(),
overwrite
)
end
end
function
InputBatch
:
setLocations
(
locations
,
vectors
,
overwrite
)
--[[locations is a n_locations x self.dimensional length set of coordinates:
torch.LongStorage or a 2-D table]]
if
type
(
locations
)
==
'table'
then
locations
=
torch
.
LongStorage
(
locations
)
end
assert
(
locations
:
min
()
>=
0
and
(
self
.
spatialSize
:
view
(
1
,
self
.
dimension
):
expandAs
(
locations
)
-
locations
):
min
()
>
0
)
C
.
dimensionFn
(
self
.
dimension
,
'setInputSpatialLocations'
)(
self
.
metadata
.
ffi
,
self
.
features
:
cdata
(),
locations
:
cdata
(),
vectors
:
cdata
(),
overwrite
)
end
function
InputBatch
:
precomputeMetadata
(
stride
)
function
InputBatch
:
precomputeMetadata
(
stride
)
if
stride
==
2
then
if
stride
==
2
then
C
.
dimensionFn
(
self
.
dimension
,
'generateRuleBooks2s2'
)(
self
.
metadata
.
ffi
)
C
.
dimensionFn
(
self
.
dimension
,
'generateRuleBooks2s2'
)(
self
.
metadata
.
ffi
)
...
...
Torch/sparseconvnet-0.1-
0
.rockspec
→
Torch/sparseconvnet-0.1-
1
.rockspec
View file @
db6454cd
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
--
LICENSE
file
in
the
root
directory
of
this
source
tree
.
--
LICENSE
file
in
the
root
directory
of
this
source
tree
.
package
=
"sparseconvnet"
package
=
"sparseconvnet"
version
=
"0.1-
0
"
version
=
"0.1-
1
"
source
=
{
source
=
{
url
=
""
,
url
=
""
,
...
...
examples/hello-world.lua
View file @
db6454cd
...
@@ -45,16 +45,24 @@ msg={
...
@@ -45,16 +45,24 @@ msg={
}
}
input
:
addSample
()
input
:
addSample
()
local
locations
=
{}
local
featureVectors
=
{}
for
y
,
line
in
ipairs
(
msg
)
do
for
y
,
line
in
ipairs
(
msg
)
do
for
x
=
1
,
string.len
(
line
)
do
for
x
=
1
,
string.len
(
line
)
do
if
string.sub
(
line
,
x
,
x
)
==
'O'
then
if
string.sub
(
line
,
x
,
x
)
==
'O'
then
local
location
=
torch
.
LongTensor
{
x
,
y
}
table.insert
(
locations
,
{
x
,
y
})
local
featureVector
=
torch
.
FloatTensor
{
1
}
table.insert
(
featureVectors
,
{
1
})
input
:
setLocation
(
location
,
featureVector
,
0
)
end
end
end
end
end
end
input
:
setLocations
(
torch
.
LongTensor
(
locations
),
torch
.
FloatTensor
(
featureVectors
),
0
)
--[[
--[[
Optional: allow metadata preprocessing to be done in batch preparation threads
Optional: allow metadata preprocessing to be done in batch preparation threads
to improve GPU utilization.
to improve GPU utilization.
...
...
examples/hello-world.py
View file @
db6454cd
...
@@ -35,12 +35,20 @@ msg = [
...
@@ -35,12 +35,20 @@ msg = [
" X X X X X X X X X X X X X X X X X X "
,
" X X X X X X X X X X X X X X X X X X "
,
" X X XXX XXX XXX XX X X XX X X XXX XXX "
]
" X X XXX XXX XXX XX X X XX X X XXX XXX "
]
input
.
addSample
()
input
.
addSample
()
locations
=
[]
features
=
[]
for
y
,
line
in
enumerate
(
msg
):
for
y
,
line
in
enumerate
(
msg
):
for
x
,
c
in
enumerate
(
line
):
for
x
,
c
in
enumerate
(
line
):
if
c
==
'X'
:
if
c
==
'X'
:
location
=
torch
.
LongTensor
([
x
,
y
])
locations
.
append
([
x
,
y
])
featureVector
=
torch
.
FloatTensor
([
1
])
features
.
append
([
1
])
input
.
setLocation
(
location
,
featureVector
,
0
)
locations
=
torch
.
LongTensor
(
locations
)
features
=
torch
.
FloatTensor
(
features
)
input
.
setLocations
(
locations
,
features
,
0
)
# Optional: allow metadata preprocessing to be done in batch preparation threads
# Optional: allow metadata preprocessing to be done in batch preparation threads
# to improve GPU utilization.
# to improve GPU utilization.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment