Unverified commit 522fecc1 authored by Kirthi Shankar Sivamani, committed by GitHub

Re-add support for PyTorch version 1.x (#180)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1a868ff3
@@ -94,13 +94,13 @@ class _SplitLastDim(torch.autograd.Function):
         noop_ok = True
         strides = grad_outputs[0].stride()
-        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
+        data_ptr = grad_outputs[0].storage().data_ptr()
         shape = grad_outputs[0].shape
         last_dim_size = grad_outputs[0].shape[-1]
         for i, tensor in enumerate(grad_outputs):
             if (tensor.stride() != strides or
                 tensor.shape != shape or
-                tensor.untyped_storage().data_ptr() != data_ptr or
+                tensor.storage().data_ptr() != data_ptr or
                 tensor.storage_offset() != i * last_dim_size):
                 noop_ok = False
                 break
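For context, the no-op check in this hunk verifies that the incoming gradients are still contiguous, equally strided views into one shared buffer, which is exactly what `torch.split` produces along the last dimension. A minimal standalone sketch of that condition (the tensor shapes are assumed for illustration, not taken from the commit):

```python
import torch

# Gradients that arrive as views of one buffer, as torch.split creates them.
full = torch.randn(4, 8)
chunks = full.split(2, dim=-1)  # four (4, 2) views; last_dim_size == 2

data_ptr = chunks[0].storage().data_ptr()
noop_ok = all(
    t.stride() == chunks[0].stride()        # identical strides
    and t.shape == chunks[0].shape          # identical shapes
    and t.storage().data_ptr() == data_ptr  # same underlying buffer
    and t.storage_offset() == i * 2         # laid out back to back
    for i, t in enumerate(chunks)
)
print(noop_ok)  # True: the chunks alias one contiguous buffer
```

When this holds, backward can hand back a view of that buffer instead of materializing a concatenation, which is what the next hunk does.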
...@@ -111,7 +111,7 @@ class _SplitLastDim(torch.autograd.Function): ...@@ -111,7 +111,7 @@ class _SplitLastDim(torch.autograd.Function):
dtype=grad_outputs[0].dtype) dtype=grad_outputs[0].dtype)
new_shape = list(shape) new_shape = list(shape)
new_shape[-1] = new_shape[-1] * len(grad_outputs) new_shape[-1] = new_shape[-1] * len(grad_outputs)
ret.set_(grad_outputs[0].untyped_storage(), ret.set_(grad_outputs[0].storage(),
grad_outputs[0].storage_offset(), grad_outputs[0].storage_offset(),
new_shape, new_shape,
grad_outputs[0].stride() grad_outputs[0].stride()
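This hunk is the payoff of the check above: when `noop_ok` holds, `Tensor.set_` rebinds an empty tensor to the gradients' shared storage with the widened shape, so the concatenated gradient is produced without a copy. A sketch under the same assumed shapes as before:

```python
import torch

full = torch.randn(4, 8)
chunks = full.split(2, dim=-1)

# Rebind an empty tensor onto the shared storage: a zero-copy "concat".
ret = torch.empty(0, dtype=chunks[0].dtype)
new_shape = list(chunks[0].shape)
new_shape[-1] *= len(chunks)  # (4, 2) -> (4, 8)
ret.set_(chunks[0].storage(),
         chunks[0].storage_offset(),
         new_shape,
         chunks[0].stride())
print(ret.data_ptr() == full.data_ptr())  # True: no data was moved
```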
@@ -277,8 +277,8 @@ class _PrepareQKVForFA(torch.autograd.Function):
         return dq, dk, dv

 def _check_if_interleaved(q, k, v):
-    data_ptr = q.untyped_storage().data_ptr()
-    check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
+    data_ptr = q.storage().data_ptr()
+    check_ptrs = all(x.storage().data_ptr() == data_ptr for x in [q, k, v])
     if not check_ptrs:
         return False
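All three hunks make the same substitution: `untyped_storage()` only exists in PyTorch >= 2.0, while `storage()` is available on 1.x (and still works, with a deprecation warning, on 2.x), so pinning to `storage()` restores 1.x compatibility. An alternative pattern, sketched below with a hypothetical helper name `_data_ptr` that does not appear in this commit, would branch on the newer API's presence instead:

```python
import torch

def _data_ptr(tensor: torch.Tensor) -> int:
    """Hypothetical compatibility shim, not part of this commit."""
    if hasattr(tensor, "untyped_storage"):  # PyTorch >= 2.0
        return tensor.untyped_storage().data_ptr()
    return tensor.storage().data_ptr()      # PyTorch 1.x fallback
```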