Commit d80ef3e2 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

utils

parent 6f80e92d
...@@ -58,4 +58,7 @@ class CheckpointedSequential(Sequential): ...@@ -58,4 +58,7 @@ class CheckpointedSequential(Sequential):
def forward(self, x): def forward(self, x):
def run(x): def run(x):
return Sequential.forward(self,x) return Sequential.forward(self,x)
if hasattr(x,'metadata'):
return checkpoint101(run, x)
else:
return torch.utils.checkpoint.checkpoint(run, x) return torch.utils.checkpoint.checkpoint(run, x)
...@@ -212,7 +212,7 @@ class checkpointFunction(torch.autograd.Function): ...@@ -212,7 +212,7 @@ class checkpointFunction(torch.autograd.Function):
ctx.x_metadata=x_metadata ctx.x_metadata=x_metadata
with torch.no_grad(): with torch.no_grad():
y = run_function( y = run_function(
scn.SparseConvNetTensor SparseConvNetTensor
(x_features, x_metadata, x_spatial_size)) (x_features, x_metadata, x_spatial_size))
return y.features return y.features
@staticmethod @staticmethod
...@@ -222,7 +222,7 @@ class checkpointFunction(torch.autograd.Function): ...@@ -222,7 +222,7 @@ class checkpointFunction(torch.autograd.Function):
x_features.requires_grad = True x_features.requires_grad = True
with torch.enable_grad(): with torch.enable_grad():
y = ctx.run_function( y = ctx.run_function(
scn.SparseConvNetTensor SparseConvNetTensor
(x_features, ctx.x_metadata, x_spatial_size)) (x_features, ctx.x_metadata, x_spatial_size))
torch.autograd.backward(y.features, grad_y_features,retain_graph=False) torch.autograd.backward(y.features, grad_y_features,retain_graph=False)
return None, x_features.grad, None, None return None, x_features.grad, None, None
...@@ -230,7 +230,7 @@ class checkpointFunction(torch.autograd.Function): ...@@ -230,7 +230,7 @@ class checkpointFunction(torch.autograd.Function):
def checkpoint101(run_function, x, down=1): def checkpoint101(run_function, x, down=1):
f=checkpointFunction.apply(run_function, x.features, x.metadata, x.spatial_size) f=checkpointFunction.apply(run_function, x.features, x.metadata, x.spatial_size)
s=x.spatial_size//down s=x.spatial_size//down
return scn.SparseConvNetTensor(f, x.metadata, s) return SparseConvNetTensor(f, x.metadata, s)
def matplotlib_cubes(ax, positions,colors): def matplotlib_cubes(ax, positions,colors):
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment