Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
5422d23a
Commit
5422d23a
authored
Jan 07, 2022
by
Lawrence McAfee
Browse files
debugging make_standalone_tensor(), safely_set_tensor_data_attr()
parent
9a8b89ac
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
2 deletions
+31
-2
megatron/model/transformer.py
megatron/model/transformer.py
+15
-0
megatron/mpu/random.py
megatron/mpu/random.py
+16
-2
No files found.
megatron/model/transformer.py
View file @
5422d23a
...
...
@@ -696,6 +696,21 @@ class ParallelTransformer(MegatronModule):
# See set_input_tensor()
hidden_states
=
self
.
input_tensor
# >>>
def make_standalone_tensor(a):
    """Wrap the data of a view tensor `a` in a freshly allocated tensor object.

    A one-element placeholder is created with `torch.empty` using the dtype
    and device of `a`; its `.data` attribute is then pointed at `a.data`, and
    the placeholder is returned.

    Args:
        a: tensor whose `_base` attribute is not None.

    Returns:
        The new tensor object, carrying `a.data` as its data.

    Raises:
        AssertionError: if `a._base` is None.
    """
    # Debugging guard: this helper is only meant to be called on tensors
    # that have a `_base`.
    assert a._base is not None

    # The (1,) shape is only a placeholder; it is replaced by the `.data`
    # assignment below.
    placeholder_options = dict(dtype=a.dtype, device=a.device)
    standalone = torch.empty((1,), **placeholder_options)
    standalone.data = a.data
    return standalone
# <<<
# hidden_states = make_standalone_tensor(hidden_states)
hidden_states
=
hidden_states
.
clone
()
# >>>
# from lutil import pax
# pax({"hidden_states": hidden_states})
# <<<
if
encoder_output
is
not
None
:
encoder_output
=
encoder_output
.
transpose
(
0
,
1
).
contiguous
()
...
...
megatron/mpu/random.py
View file @
5422d23a
...
...
@@ -98,6 +98,13 @@ def gather_split_1d_tensor(tensor):
group
=
get_tensor_model_parallel_group
())
return
gathered
def safely_set_tensor_data_attr(tensor, new_data_tensor):
    """Assign `new_data_tensor` to `tensor.data`, refusing tensors with a `_base`.

    Per the error message below, overwriting `.data` on a tensor whose
    `_base` is set leaks memory, likely accumulating over iterations. This
    helper therefore checks `tensor._base` before performing the assignment.

    Args:
        tensor: tensor whose `.data` attribute is replaced; its `_base`
            must be None.
        new_data_tensor: tensor to install as `tensor.data`.

    Raises:
        AssertionError: if `tensor._base` is not None. `tensor.data` is left
            untouched in that case.
    """
    if tensor._base is not None:
        # Raised explicitly instead of through an `assert` statement so the
        # guard stays active when Python runs with -O (which strips asserts).
        raise AssertionError(
            (
                "Ensure tensor._base is None before setting tensor.data. Otherwise, "
                "a memory leak will occur (and likely accumulate over iterations). "
                "FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
            ) % (tensor._base.shape, new_data_tensor.shape)
        )
    tensor.data = new_data_tensor
class
CudaRNGStatesTracker
:
"""Tracker for the cuda RNG states.
...
...
@@ -241,9 +248,16 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if
distribute_checkpointed_activations
:
# >>>
# from lutil import data_leak_ctx
# with data_leak_ctx(args[0]):
# <<<
ctx
.
input_0_shape
=
args
[
0
].
data
.
shape
args
[
0
].
data
=
split_tensor_into_1d_equal_chunks
(
args
[
0
].
data
,
new_buffer
=
True
)
# args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
# new_buffer=True)
safely_set_tensor_data_attr
(
args
[
0
],
split_tensor_into_1d_equal_chunks
(
args
[
0
].
data
,
new_buffer
=
True
))
# Store everything.
ctx
.
save_for_backward
(
*
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment