Commit 0b84ab19 authored by Myle Ott's avatar Myle Ott
Browse files

Fix training

parent 5eddda8b
...@@ -197,7 +197,10 @@ def make_variable(sample, volatile=False, cuda_device=None): ...@@ -197,7 +197,10 @@ def make_variable(sample, volatile=False, cuda_device=None):
if torch.is_tensor(maybe_tensor): if torch.is_tensor(maybe_tensor):
if cuda_device is not None and torch.cuda.is_available(): if cuda_device is not None and torch.cuda.is_available():
maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device) maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
return volatile_variable(maybe_tensor) if volatile:
return volatile_variable(maybe_tensor)
else:
return Variable(maybe_tensor)
elif isinstance(maybe_tensor, dict): elif isinstance(maybe_tensor, dict):
return { return {
key: _make_variable(value) key: _make_variable(value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment