Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
522fecc1
Unverified
Commit
522fecc1
authored
Apr 27, 2023
by
Kirthi Shankar Sivamani
Committed by
GitHub
Apr 27, 2023
Browse files
Re-add support for PyTorch version 1.x (#180)
Signed-off-by:
Kirthi Shankar Sivamani
<
ksivamani@nvidia.com
>
parent
1a868ff3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
transformer_engine/pytorch/transformer.py
transformer_engine/pytorch/transformer.py
+5
-5
No files found.
transformer_engine/pytorch/transformer.py
View file @
522fecc1
...
@@ -94,13 +94,13 @@ class _SplitLastDim(torch.autograd.Function):
...
@@ -94,13 +94,13 @@ class _SplitLastDim(torch.autograd.Function):
noop_ok
=
True
noop_ok
=
True
strides
=
grad_outputs
[
0
].
stride
()
strides
=
grad_outputs
[
0
].
stride
()
data_ptr
=
grad_outputs
[
0
].
untyped_
storage
().
data_ptr
()
data_ptr
=
grad_outputs
[
0
].
storage
().
data_ptr
()
shape
=
grad_outputs
[
0
].
shape
shape
=
grad_outputs
[
0
].
shape
last_dim_size
=
grad_outputs
[
0
].
shape
[
-
1
]
last_dim_size
=
grad_outputs
[
0
].
shape
[
-
1
]
for
i
,
tensor
in
enumerate
(
grad_outputs
):
for
i
,
tensor
in
enumerate
(
grad_outputs
):
if
(
tensor
.
stride
()
!=
strides
or
if
(
tensor
.
stride
()
!=
strides
or
tensor
.
shape
!=
shape
or
tensor
.
shape
!=
shape
or
tensor
.
untyped_
storage
().
data_ptr
()
!=
data_ptr
or
tensor
.
storage
().
data_ptr
()
!=
data_ptr
or
tensor
.
storage_offset
()
!=
i
*
last_dim_size
):
tensor
.
storage_offset
()
!=
i
*
last_dim_size
):
noop_ok
=
False
noop_ok
=
False
break
break
...
@@ -111,7 +111,7 @@ class _SplitLastDim(torch.autograd.Function):
...
@@ -111,7 +111,7 @@ class _SplitLastDim(torch.autograd.Function):
dtype
=
grad_outputs
[
0
].
dtype
)
dtype
=
grad_outputs
[
0
].
dtype
)
new_shape
=
list
(
shape
)
new_shape
=
list
(
shape
)
new_shape
[
-
1
]
=
new_shape
[
-
1
]
*
len
(
grad_outputs
)
new_shape
[
-
1
]
=
new_shape
[
-
1
]
*
len
(
grad_outputs
)
ret
.
set_
(
grad_outputs
[
0
].
untyped_
storage
(),
ret
.
set_
(
grad_outputs
[
0
].
storage
(),
grad_outputs
[
0
].
storage_offset
(),
grad_outputs
[
0
].
storage_offset
(),
new_shape
,
new_shape
,
grad_outputs
[
0
].
stride
()
grad_outputs
[
0
].
stride
()
...
@@ -277,8 +277,8 @@ class _PrepareQKVForFA(torch.autograd.Function):
...
@@ -277,8 +277,8 @@ class _PrepareQKVForFA(torch.autograd.Function):
return
dq
,
dk
,
dv
return
dq
,
dk
,
dv
def
_check_if_interleaved
(
q
,
k
,
v
):
def
_check_if_interleaved
(
q
,
k
,
v
):
data_ptr
=
q
.
untyped_
storage
().
data_ptr
()
data_ptr
=
q
.
storage
().
data_ptr
()
check_ptrs
=
all
(
x
.
untyped_
storage
().
data_ptr
()
==
data_ptr
for
x
in
[
q
,
k
,
v
])
check_ptrs
=
all
(
x
.
storage
().
data_ptr
()
==
data_ptr
for
x
in
[
q
,
k
,
v
])
if
not
check_ptrs
:
if
not
check_ptrs
:
return
False
return
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment