Commit 062f0665 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

at:: to torch::

parent 9aea4e05
......@@ -94,14 +94,6 @@ locations = torch.LongTensor(locations)
features = torch.FloatTensor(features)
input.set_locations(locations, features, 0)
# Optional: allow metadata preprocessing to be done in batch preparation threads
# to improve GPU utilization.
#
# Parameter:
# 3 if using MP3/2 pooling or C3/2 convolutions for downsizing,
# 2 if using MP2 pooling for downsizing.
input.precomputeMetadata(3)
model.train()
if use_gpu:
input.cuda()
......
......@@ -58,14 +58,6 @@ locations = torch.LongTensor(locations)
features = torch.FloatTensor(features)
input.set_locations(locations, features, 0)
# Optional: allow metadata preprocessing to be done in batch preparation threads
# to improve GPU utilization.
#
# Parameter:
# 3 if using MP3/2 pooling or C3/2 convolutions for downsizing,
# 2 if using MP2 pooling for downsizing.
input.precomputeMetadata(3)
model.train()
if use_cuda:
input.cuda()
......
......@@ -145,15 +145,13 @@ void Metadata<dimension>::setInputSpatialLocations(
}
template <Int dimension>
void Metadata<dimension>::getSpatialLocations(/*long*/ at::Tensor spatialSize,
/*long*/ at::Tensor locations) {
at::Tensor
Metadata<dimension>::getSpatialLocations(/*long*/ at::Tensor spatialSize) {
Int nActive = getNActive(spatialSize);
auto &SGs = getSparseGrid(spatialSize);
Int batchSize = SGs.size();
locations.resize_({(int)nActive, dimension + 1});
locations.zero_();
auto locations = torch::zeros({(int)nActive, dimension + 1}, at::kLong);
auto lD = locations.data<long>();
for (Int i = 0; i < batchSize; i++) {
......@@ -166,6 +164,7 @@ void Metadata<dimension>::getSpatialLocations(/*long*/ at::Tensor spatialSize,
lD[(it->second + offset) * (dimension + 1) + dimension] = i;
}
}
return locations;
}
template <Int dimension>
void Metadata<dimension>::createMetadataForDenseToSparse(
......@@ -260,7 +259,7 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
Metadata<dimension> &mSparsified,
/*long*/ at::Tensor spatialSize) {
auto p = LongTensorToPoint<dimension>(spatialSize);
at::Tensor delta = at::zeros({nActive[p]}, at::kFloat);
at::Tensor delta = torch::zeros({nActive[p]}, at::kFloat);
float *deltaPtr = delta.data<float>();
auto &sgsReference = mReference.grids[p];
auto &sgsFull = grids[p];
......@@ -588,13 +587,13 @@ Metadata<dimension>::compareSparseHelper(Metadata<dimension> &mR,
}
}
}
at::Tensor cL_ = at::empty({(long)cL.size()}, at::CPU(at::kLong));
at::Tensor cL_ = torch::empty({(long)cL.size()}, at::CPU(at::kLong));
std::memcpy(cL_.data<long>(), &cL[0], cL.size() * sizeof(long));
at::Tensor cR_ = at::empty({(long)cR.size()}, at::CPU(at::kLong));
at::Tensor cR_ = torch::empty({(long)cR.size()}, at::CPU(at::kLong));
std::memcpy(cR_.data<long>(), &cR[0], cR.size() * sizeof(long));
at::Tensor L_ = at::empty({(long)L.size()}, at::CPU(at::kLong));
at::Tensor L_ = torch::empty({(long)L.size()}, at::CPU(at::kLong));
std::memcpy(L_.data<long>(), &L[0], L.size() * sizeof(long));
at::Tensor R_ = at::empty({(long)R.size()}, at::CPU(at::kLong));
at::Tensor R_ = torch::empty({(long)R.size()}, at::CPU(at::kLong));
std::memcpy(R_.data<long>(), &R[0], R.size() * sizeof(long));
return {cL_, cR_, L_, R_};
}
......
......@@ -92,8 +92,7 @@ public:
/*long*/ at::Tensor locations,
/*float*/ at::Tensor vecs, bool overwrite);
void getSpatialLocations(/*long*/ at::Tensor spatialSize,
/*long*/ at::Tensor locations);
at::Tensor getSpatialLocations(/*long*/ at::Tensor spatialSize);
void createMetadataForDenseToSparse(/*long*/ at::Tensor spatialSize,
/*long*/ at::Tensor nz_, long batchSize);
......
......@@ -19,8 +19,7 @@ class SparseConvNetTensor(object):
"Coordinates and batch index for the active spatial locations"
if spatial_size is None:
spatial_size = self.spatial_size
t = torch.LongTensor()
self.metadata.getSpatialLocations(spatial_size, t)
t = self.metadata.getSpatialLocations(spatial_size)
return t
def type(self, t=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment