OpenDAS / Megatron-LM · Commit ed0c8714
authored Jan 07, 2022 by Lawrence McAfee

more iterating on 'viewless tensor' methods

parent 5422d23a
Showing 4 changed files with 106 additions and 16 deletions (+106, -16):
megatron/model/transformer.py   +8   -10
megatron/mpu/random.py          +57  -6
megatron/p2p_communication.py   +5   -0
megatron/schedules.py           +36  -0
megatron/model/transformer.py

@@ -27,6 +27,9 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<

""" We use the following notation throughout this file:
     h: hidden size

@@ -696,19 +699,14 @@ class ParallelTransformer(MegatronModule):
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # >>>
        def make_standalone_tensor(a):
            assert a._base is not None
            b = torch.empty((1,), dtype=a.dtype, device=a.device)
            b.data = a.data
            return b
        # <<<
        # hidden_states = make_standalone_tensor(hidden_states)
        hidden_states = hidden_states.clone()
        # hidden_states = MakeStandaloneTensor.apply(hidden_states)
        # hidden_states = MakeViewlessTensor.apply(hidden_states)
        hidden_states = make_viewless_tensor(hidden_states)
        # hidden_states = hidden_states.clone()
        # >>>
        # from lutil import pax
        # pax({"hidden_states": hidden_states})
        # pax(0, {"hidden_states": hidden_states})
        # <<<

        if encoder_output is not None:
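For intuition, a minimal standalone sketch (not part of this commit; stock PyTorch only) of the difference between the two candidate fixes above: `.clone()` detaches `hidden_states` from its base by copying memory, while `make_viewless_tensor` re-wraps the existing storage in a fresh tensor object, so no copy is made and (per the autograd function below) gradient still flows through as the identity:

import torch

base = torch.randn(1024)
view = base[:32].view(4, 8)        # view._base is base

cloned = view.clone()              # new storage: copies the 32 floats
rewrapped = torch.empty((1,), dtype=view.dtype, device=view.device)
rewrapped.data = view.data         # shares view's storage; no view relationship

print(cloned._base is None)                     # True: owns its own memory
print(rewrapped._base is None)                  # True: safe to .data-swap later
print(rewrapped.data_ptr() == view.data_ptr())  # True: no copy was made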
megatron/mpu/random.py

@@ -98,13 +98,54 @@ def gather_split_1d_tensor(tensor):
                                       group=get_tensor_model_parallel_group())
    return gathered


def safely_set_tensor_data_attr(tensor, new_data_tensor):  # old definition; replaced at the end of this hunk
# >>>
# from lutil import pax

# def make_standalone_tensor(a):
#     assert a._base is not None
#     b = torch.empty((1,), dtype=a.dtype, device=a.device)
#     b.data = a.data
#     return b

# class MakeStandaloneTensor(torch.autograd.Function):
class MakeViewlessTensor_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inp):
        assert inp._base is not None
        out = torch.empty((1,), dtype=inp.dtype, device=inp.device)
        out.data = inp.data
        # pax(0, {"inp": inp, "out": out})
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # pax(0, {"grad_output": grad_output})
        return grad_output

def make_viewless_tensor(tensor):
    if tensor._base is None:
        return tensor
    else:
        return MakeViewlessTensor_.apply(tensor)

def assert_viewless_tensor(tensor):
    if isinstance(tensor, list):
        [assert_viewless_tensor(t) for t in tensor]
        return
    # assert isinstance(tensor, torch.Tensor), \
    #     "expected Tensor; found %s." % type(tensor).__name__
    if not isinstance(tensor, torch.Tensor):
        return
    assert tensor._base is None, (
        # the diff shows both versions of this message; before this commit:
        "Ensure tensor._base is None before setting tensor.data. Otherwise, "
        "a memory leak will occur (and likely accumulate over iterations). "
        "FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
        # after this commit:
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). FYI, tensor._base has shape "
        "%s, and new_data_tensor has shape %s."
    ) % (tensor._base.shape, new_data_tensor.shape)
    # (note: new_data_tensor is not a parameter of assert_viewless_tensor, so
    # formatting this message would raise a NameError if the assert fired)

# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
def safely_set_tensor_data_attr(tensor, new_data_tensor):
    assert_viewless_tensor(tensor)
    tensor.data = new_data_tensor
# <<<


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
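The hazard these helpers guard against, as a minimal sketch (stock PyTorch; assumes, as the assert message above does, that assigning `.data` leaves `._base` intact):

import torch

big = torch.randn(1 << 20)      # ~4 MB of activation memory
view = big[:16]                 # view._base is big
del big                         # NOT freed: view._base still references it

view.data = torch.zeros(16)     # swapping .data does not clear ._base, so the
                                # ~4 MB stays allocated for as long as `view`
                                # lives -- the leak that accumulates over
                                # iterations
print(view._base.numel())       # 1048576

make_viewless_tensor breaks exactly this link: it returns a tensor with the same .data but ._base set to None, so a later .data swap actually frees the old storage.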
@@ -253,11 +294,13 @@ class CheckpointFunction(torch.autograd.Function):
        # with data_leak_ctx(args[0]):
        # <<<
        ctx.input_0_shape = args[0].data.shape
        # >>>
        # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
        #                                                  new_buffer=True)
        safely_set_tensor_data_attr(
            args[0],
            split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
        # <<<

        # Store everything.
        ctx.save_for_backward(*args)
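For context, `split_tensor_into_1d_equal_chunks` keeps only this rank's shard of the flattened activation (the `--distribute-checkpointed-activations` optimization), so `args[0]` shrinks by a factor of the tensor-model-parallel world size until backward gathers it again. A single-process sketch, with `split_1d`, `world_size`, and `rank` as hypothetical stand-ins for the mpu machinery:

import torch

world_size, rank = 4, 1                  # hypothetical: 4 tensor-parallel ranks

def split_1d(t):
    # stand-in for split_tensor_into_1d_equal_chunks(t, new_buffer=True)
    chunk = t.numel() // world_size
    return t.view(-1)[rank * chunk:(rank + 1) * chunk].clone()

x = torch.randn(8, 4)                    # checkpointed input; x._base is None
input_0_shape = x.data.shape             # remembered for the backward pass
x.data = split_1d(x.data)                # the swap safely_set_... now guards
print(x.data.numel())                    # 8 (= 32 / world_size)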
@@ -271,8 +314,16 @@ class CheckpointFunction(torch.autograd.Function):
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_checkpointed_activations:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
            # >>>
            # inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            # inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
            safely_set_tensor_data_attr(
                inputs[0], gather_split_1d_tensor(inputs[0].data))
            safely_set_tensor_data_attr(
                inputs[0], inputs[0].data.view(ctx.input_0_shape))
            # <<<

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
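One subtlety in this restore path: after the first `safely_set_tensor_data_attr` call, `inputs[0].data` holds the gathered, viewless buffer, and the second call assigns a `.view()` of that buffer back onto the same tensor. That appears to be fine because the guard checks `tensor._base` on the tensor being written to, not on the new data; the view and its base share one storage, so nothing extra is kept alive. A small sketch:

import torch

t = torch.randn(32)             # viewless: t._base is None
t.data = t.data.view(8, 4)      # the new data is a view, but the guarded
                                # tensor itself is not, so the assert passes
print(t.shape)                  # torch.Size([8, 4])
print(t._base is None)          # True: .data assignment does not set ._base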
megatron/p2p_communication.py

@@ -20,6 +20,9 @@ import torch
from megatron import get_args
from megatron import mpu
# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<

def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,

@@ -142,10 +145,12 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()
            tensor_recv_next = make_viewless_tensor(tensor_recv_next)

    return tensor_recv_prev, tensor_recv_next
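Here the received buffers are rebuilt by `gather_split_1d_tensor(...).view(tensor_shape)`, and `.view()` is precisely what leaves a live `._base` behind; converting them to viewless tensors at the communication boundary means every tensor entering the schedule already satisfies the invariant. A sketch (stock PyTorch; `flat` stands in for the all-gathered 1-D buffer):

import torch

flat = torch.empty(6)                    # stand-in for the gathered 1-D buffer
recv = flat.view(2, 3).requires_grad_()  # how tensor_recv_prev is constructed
print(recv._base is flat)                # True: a view, unsafe to .data-swap
# make_viewless_tensor(recv) re-wraps the same storage with ._base set to None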
megatron/schedules.py

@@ -28,6 +28,10 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
# >>>
from megatron.mpu.random import assert_viewless_tensor
# <<<

def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -306,6 +310,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                     model[model_chunk_id], input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)
        assert_viewless_tensor(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
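From here on, every tensor appended to the pipeline bookkeeping lists is checked as it arrives. `assert_viewless_tensor` accepts lists and non-tensors (entries can be None), recursing only into tensors. A self-contained sketch of that traversal (a simplified restatement, not an import of the new helper):

import torch

def check_viewless(t):
    # simplified version of assert_viewless_tensor's checking logic
    if isinstance(t, list):
        for x in t:
            check_viewless(x)
        return
    if not isinstance(t, torch.Tensor):
        return
    assert t._base is None, "view detected; its base would be kept alive"

out = torch.randn(4, 4)
check_viewless([out, None, out.clone()])  # passes: lists / non-tensors are fine
# check_viewless(out.view(16))            # would fail: ._base is out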
@@ -339,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    assert_viewless_tensor(input_tensors[0][-1])

    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

@@ -370,6 +376,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    tensor_shape=tensor_shape, timers=timers)
                output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
                assert_viewless_tensor(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(

@@ -378,6 +385,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    timers=timers)
        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        assert_viewless_tensor(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):

@@ -447,15 +455,18 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
            assert_viewless_tensor(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
            assert_viewless_tensor(output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
            assert_viewless_tensor(output_tensor_grads[num_model_chunks-1][-1])
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -470,6 +481,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))
            assert_viewless_tensor(output_tensor_grads[next_backward_model_chunk_id][-1])

    return losses_reduced
@@ -508,6 +520,7 @@ def recv_forward(tensor_shapes, timers):
        else:
            input_tensors.append(p2p_communication.recv_forward(tensor_shape,
                                                                timers=timers))
            assert_viewless_tensor(input_tensors[-1])
    return input_tensors

@@ -519,6 +532,7 @@ def recv_backward(tensor_shapes, timers):
        else:
            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
                                                                       timers=timers))
            assert_viewless_tensor(output_tensor_grads[-1])
    return output_tensor_grads

@@ -551,6 +565,7 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
        output_tensor_grad = p2p_communication.send_forward_recv_backward(
            output_tensor, tensor_shape, timers=timers)
        output_tensor_grads.append(output_tensor_grad)
        assert_viewless_tensor(output_tensor_grad)
    return output_tensor_grads

@@ -565,6 +580,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
        input_tensor = p2p_communication.send_backward_recv_forward(
            input_tensor_grad, tensor_shape, timers=timers)
        input_tensors.append(input_tensor)
        assert_viewless_tensor(input_tensor)
    return input_tensors
@@ -615,6 +631,15 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
            send_forward(output_tensor, send_tensor_shapes, timers=timers)

            if not forward_only:
                # >>>
                if input_tensor[0] is not None:
                    from lutil import pax
                    pax({"input_tensor": input_tensor})
                # <<<
                assert_viewless_tensor(input_tensor)
                assert_viewless_tensor(output_tensor)
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)
                free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
@@ -644,6 +669,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                timers=timers)

            # Add input_tensor and output_tensor to end of list.
            # >>>
            # assert input_tensor[0]._base is None, \
            #     "rank %s; uh oh." % torch.distributed.get_rank()
            # if input_tensor[0] is not None:
            #     from lutil import pax
            #     pax(4, {
            #         "input_tensor[0]" : input_tensor[0],
            #     })
            # <<<
            assert_viewless_tensor(input_tensor)
            assert_viewless_tensor(output_tensor)
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)