Commit 2f6072ed authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

contiguous check

parent de3743f6
...@@ -25,7 +25,7 @@ class InputBatch(SparseConvNetTensor): ...@@ -25,7 +25,7 @@ class InputBatch(SparseConvNetTensor):
def set_location(self, location, vector, overwrite=False): def set_location(self, location, vector, overwrite=False):
assert location.min() >= 0 and (self.spatial_size - location).min() > 0 assert location.min() >= 0 and (self.spatial_size - location).min() > 0
self.metadata.setInputSpatialLocation( self.metadata.setInputSpatialLocation(
self.features, location, vector, overwrite) self.features, location.contiguous(), vector.contiguous(), overwrite)
def set_location_(self, location, vector, overwrite=False): def set_location_(self, location, vector, overwrite=False):
self.metadata.setInputSpatialLocation( self.metadata.setInputSpatialLocation(
...@@ -53,7 +53,7 @@ class InputBatch(SparseConvNetTensor): ...@@ -53,7 +53,7 @@ class InputBatch(SparseConvNetTensor):
l = locations[:, :self.dimension] l = locations[:, :self.dimension]
assert l.min() >= 0 and (self.spatial_size.expand_as(l) - l).min() > 0 assert l.min() >= 0 and (self.spatial_size.expand_as(l) - l).min() > 0
self.metadata.setInputSpatialLocations( self.metadata.setInputSpatialLocations(
self.features, locations, vectors, overwrite) self.features, locations.contiguous(), vectors.contiguous(), overwrite)
def set_locations_(self, locations, vector, overwrite=False): def set_locations_(self, locations, vector, overwrite=False):
self.metadata.setInputSpatialLocations( self.metadata.setInputSpatialLocations(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment