OpenDAS / Megatron-LM

Commit 270d6412, authored Jan 10, 2022 by Lawrence McAfee
Parent: b6d4369b

loss matches; memory savings for multi-node (tested n3, n16)
Showing 4 changed files with 70 additions and 46 deletions (+70 -46)
megatron/model/transformer.py     +6   -9
megatron/mpu/random.py           +58  -19
megatron/p2p_communication.py     +6   -2
megatron/schedules.py             +0  -16
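For context on why the changes below matter (an illustrative note, not part of the commit): a tensor produced by .view() keeps a reference to its base buffer via ._base, so storing the view or later overwriting its .data leaves the entire base buffer pinned in memory. The commit converts such tensors into "viewless" tensors that share storage without the view relationship. A minimal sketch in plain PyTorch; the variable names here are illustrative:

    import torch

    base = torch.empty(1024, 1024)   # large buffer
    view = base.view(-1)             # a 'view': shares storage with base
    assert view._base is not None    # the view keeps the full base buffer alive

    fresh = torch.empty(1024 * 1024)
    assert fresh._base is None       # a freshly allocated tensor is 'viewless'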
megatron/model/transformer.py

@@ -699,15 +699,12 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor
-            # hidden_states = make_standalone_tensor(hidden_states)
-            # hidden_states = MakeStandaloneTensor.apply(hidden_states)
-            # hidden_states = MakeViewlessTensor.apply(hidden_states)
-            hidden_states = make_viewless_tensor(hidden_states)
-            # hidden_states = hidden_states.clone()
-            # >>>
-            # from lutil import pax
-            # pax(0, {"hidden_states": hidden_states})
-            # <<<
+            # Viewless tensor
+            hidden_states = make_viewless_tensor(hidden_states,
+                                                 requires_grad = True,
+                                                 keep_graph = True,
+            )
 
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
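keep_graph=True is used above because hidden_states (the activation handed over from the previous pipeline stage) still needs to participate in backprop. A self-contained sketch of that path, condensed from the MakeViewlessTensor code later in this diff rather than imported from the repository, showing that the result is no longer a view while gradients still reach the original tensor:

    import torch

    class _MakeViewless(torch.autograd.Function):
        """Minimal stand-in for MakeViewlessTensor from this commit."""
        @staticmethod
        def forward(ctx, inp, requires_grad):
            out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                              requires_grad=requires_grad)
            out.data = inp.data          # share storage, but out._base stays None
            return out

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output, None     # pass gradient through; None for the bool arg

    x = torch.randn(4, 4, requires_grad=True)
    v = x.view(16)                       # differentiable view: v._base is not None
    out = _MakeViewless.apply(v, True)

    assert out._base is None             # 'viewless', yet shares x's storage
    out.sum().backward()                 # gradient still flows back to x
    assert x.grad is not None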
megatron/mpu/random.py

@@ -99,7 +99,7 @@ def gather_split_1d_tensor(tensor):
     return gathered
 
 # >>>
-# from lutil import pax
+from lutil import pax
 # ****************
 # def make_standalone_tensor(a):
 #     assert a._base is not None
@@ -107,26 +107,66 @@ def gather_split_1d_tensor(tensor):
 #     b.data = a.data
 #     return b
 # class MakeStandaloneTensor(torch.autograd.Function):
-class MakeViewlessTensor_(torch.autograd.Function):
+# class MakeViewlessTensor_(torch.autograd.Function):
+class MakeViewlessTensor(torch.autograd.Function):
 #     @staticmethod
 #     def forward(ctx, inp):
 #         assert inp._base is not None
 #         out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
 #         out.data = inp.data
 #         # pax(0, {"inp": inp, "out": out})
 #         return out
     @staticmethod
-    def forward(ctx, inp):
-        assert inp._base is not None
-        out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
-        out.data = inp.data
-        # pax(0, {"inp": inp, "out": out})
-        return out
+    def forward(ctx, inp, requires_grad):
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # @staticmethod
+    # def forward(ctx, args):
+    #     return [_kernel_make_viewless_tensor(*args)]
     @staticmethod
     def backward(ctx, grad_output):
         # pax(0, {"grad_output": grad_output})
-        return grad_output
+        # return grad_output
+        return grad_output, None
+
+def _kernel_make_viewless_tensor(inp, requires_grad):
+    out = torch.empty(
+        (1,),
+        dtype = inp.dtype,
+        device = inp.device,
+        requires_grad = requires_grad,
+    )
+    out.data = inp.data
+    # >>>
+    # pax(0, {"inp": inp, "out": out})
+    # assert out.requires_grad
+    # <<<
+    return out
 
-def make_viewless_tensor(tensor):
-    if tensor._base is None:
-        return tensor
-    else:
-        return MakeViewlessTensor_.apply(tensor)
+# def make_viewless_tensor(tensor):
+#     if tensor._base is None:
+#         return tensor
+#     else:
+#         return MakeViewlessTensor_.apply(tensor)
+def make_viewless_tensor(inp, requires_grad, keep_graph):
+
+    # return tensor as-is, if not a 'view'
+    if inp._base is None:
+        return inp
+
+    # create viewless tensor
+    if keep_graph:
+        # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+        return MakeViewlessTensor.apply(inp, requires_grad)
+    else:
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+    # return MakeViewlessTensor.apply(inp, requires_grad)
+    # return MakeViewlessTensor.apply(inp)
+    # return MakeViewlessTensor.apply(inp, 7)
+    # return MakeViewlessTensor.apply(inp, 7)[0]
 
-def assert_viewless_tensor(tensor):
+def assert_viewless_tensor(tensor, extra_msg = None):
     if isinstance(tensor, list):
         [ assert_viewless_tensor(t) for t in tensor ]
         return
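A rough standalone check of the keep_graph=False branch above: _kernel_make_viewless_tensor returns a tensor that shares the input's storage but is a fresh autograd leaf, so it carries neither the view relationship nor the upstream graph. Everything below except that logic (buf, v, out) is an illustrative name, not from the repository:

    import torch

    def _kernel_make_viewless_tensor(inp, requires_grad):
        # Same trick as in this diff: allocate a throwaway (1,)-shaped tensor,
        # then point its .data at the view's storage.
        out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                          requires_grad=requires_grad)
        out.data = inp.data
        return out

    buf = torch.randn(8, 8)
    v = buf.view(64)                                 # a view of buf
    out = _kernel_make_viewless_tensor(v, requires_grad=True)

    assert out._base is None                         # viewless
    assert out.data_ptr() == buf.data_ptr()          # still shares buf's storage
    assert out.grad_fn is None                       # fresh leaf: no upstream graph kept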
@@ -137,13 +177,12 @@ def assert_viewless_tensor(tensor):
     assert tensor._base is None, (
         "Ensure tensor._base is None before setting tensor.data or storing "
         "tensor to memory buffer. Otherwise, a memory leak will occur (and "
-        "likely accumulate over iterations). FYI, tensor._base has shape "
-        "%s, and new_data_tensor has shape %s."
-    ) % (tensor._base.shape, new_data_tensor.shape)
+        "likely accumulate over iterations). %s"
+    ) % extra_msg
 
-# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
 def safely_set_tensor_data_attr(tensor, new_data_tensor):
-    assert_viewless_tensor(tensor)
+    assert_viewless_tensor(tensor, extra_msg =
+        "FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
+        % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
     tensor.data = new_data_tensor
 # <<<
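Usage of the guard above, rewritten as a self-contained sketch (the helper mirrors safely_set_tensor_data_attr from this diff; it is not imported from the repository): overwriting .data is only allowed on viewless tensors, since doing it on a view would leave the view's base buffer pinned.

    import torch

    def safely_set_tensor_data_attr(tensor, new_data_tensor):
        # Mirrors the helper in this diff: refuse to overwrite .data on a view.
        assert tensor._base is None, (
            "Ensure tensor._base is None before setting tensor.data. "
            "tensor._base has shape %s, new_data_tensor has shape %s."
            % ("--" if tensor._base is None else tensor._base.shape,
               new_data_tensor.shape)
        )
        tensor.data = new_data_tensor

    t = torch.empty(16)
    safely_set_tensor_data_attr(t, torch.zeros(16))      # fine: t is viewless

    v = torch.empty(4, 4).view(16)
    # safely_set_tensor_data_attr(v, torch.zeros(16))    # would raise: v is a view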
megatron/p2p_communication.py

@@ -145,12 +145,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
+            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
         if recv_next:
             tensor_recv_next = mpu.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = make_viewless_tensor(tensor_recv_next)
+            tensor_recv_next = make_viewless_tensor(tensor_recv_next,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
 
     return tensor_recv_prev, tensor_recv_next
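On the receive path above, .view(tensor_shape) turns the gathered 1-D buffer into a view, which is why make_viewless_tensor is applied afterwards; keep_graph=False appears sufficient here because a freshly received buffer has no upstream autograd history worth preserving. A small mock of the sequence in plain torch (no mpu; the shape and names are illustrative):

    import torch

    tensor_shape = (2, 4, 8)
    flat_recv = torch.empty(2 * 4 * 8)               # stands in for the gathered 1-D buffer
    tensor_recv = flat_recv.view(tensor_shape).requires_grad_()

    assert tensor_recv._base is not None   # reshaping leaves a view that pins flat_recv,
                                           # hence make_viewless_tensor(..., keep_graph=False)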
megatron/schedules.py

@@ -631,13 +631,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:
-            # >>>
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax({
-            #         "input_tensor" : input_tensor,
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
@@ -669,15 +662,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                  timers=timers)
 
             # Add input_tensor and output_tensor to end of list.
-            # >>>
-            # assert input_tensor[0]._base is None, \
-            #     "rank %s; uh oh." % torch.distributed.get_rank()
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax(4, {
-            #         "input_tensor[0]" : input_tensor[0],
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
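In this schedule, input_tensor and output_tensor passed to assert_viewless_tensor are lists of tensors, which is why the version added earlier in this diff recurses into lists before checking ._base. A condensed, self-contained version of that check (not an import from the repository):

    import torch

    def assert_viewless_tensor(tensor):
        # Condensed from this diff: recurse into lists, then require non-views.
        if isinstance(tensor, list):
            for t in tensor:
                assert_viewless_tensor(t)
            return
        assert tensor._base is None, "found a view tensor; its base buffer stays pinned"

    assert_viewless_tensor([torch.empty(4), torch.empty(2, 2)])   # passes
    # assert_viewless_tensor([torch.empty(4, 4).view(16)])        # would fail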