Unverified Commit 646c11b0 authored by Laurens van der Maaten, committed by GitHub
Browse files

Expose option to perform async CUDA copies

See title. I tested the `async` flag locally, and it appears to work as expected.
parent cd5f3fc1
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from .utils import dim_fn from .utils import dim_fn
from torch.autograd import Variable from torch.autograd import Variable
class SparseConvNetTensor(object): class SparseConvNetTensor(object):
def __init__(self, features=None, metadata=None, spatial_size=None): def __init__(self, features=None, metadata=None, spatial_size=None):
self.features = features self.features = features
...@@ -30,8 +31,8 @@ class SparseConvNetTensor(object): ...@@ -30,8 +31,8 @@ class SparseConvNetTensor(object):
return self return self
return self.features.type() return self.features.type()
def cuda(self): def cuda(self, async=False):
self.features = self.features.cuda() self.features = self.features.cuda(async=async)
return self return self
def cpu(self): def cpu(self):
...@@ -47,7 +48,7 @@ class SparseConvNetTensor(object): ...@@ -47,7 +48,7 @@ class SparseConvNetTensor(object):
return 'SparseConvNetTensor<<' + \ return 'SparseConvNetTensor<<' + \
repr(self.features) + repr(self.metadata) + repr(self.spatial_size) + '>>' repr(self.features) + repr(self.metadata) + repr(self.spatial_size) + '>>'
def to_variable(self, requires_grad = False, volatile=False): def to_variable(self, requires_grad=False, volatile=False):
"Convert self.features to a variable for use with modern PyTorch interface." "Convert self.features to a variable for use with modern PyTorch interface."
self.features=Variable(self.features, requires_grad=requires_grad, volatile=volatile) self.features = Variable(self.features, requires_grad=requires_grad, volatile=volatile)
return self return self
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment