OpenDAS / Megatron-LM · Commits

Commit 5bc9f889, authored Jan 28, 2022 by Lawrence McAfee
Parent: f17a3933

    narrowed issue to pipeline rank 0, virtual pipeline rank >= 1
Showing 4 changed files with 52 additions and 8 deletions (+52 / -8):

    megatron/model/transformer.py    +12  -0
    megatron/p2p_communication.py    +19  -6
    megatron/schedules.py             +9  -0
    megatron/training.py             +12  -2
megatron/model/transformer.py

@@ -698,6 +698,18 @@ class ParallelTransformer(MegatronModule):
...
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # >>>
        # if not self.pre_process and self.num_layers == 0:
        #     # raise Exception("tp %d, pp %d, vp %d ... hidden states %s, input tensor %s." % (
        #     #     mpu.get_tensor_model_parallel_rank(),
        #     #     mpu.get_pipeline_model_parallel_rank(),
        #     #     mpu.get_virtual_pipeline_model_parallel_rank(),
        #     #     "--" if hidden_states is None else str(hidden_states.shape),
        #     #     "--" if self.input_tensor is None else str(self.input_tensor.shape),
        #     #     ))
        #     hidden_states = hidden_states.clone()
        # <<<

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
...
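The "viewless tensor" comment above is the heart of this commit. Here is a minimal, hedged sketch of the distinction it relies on, independent of Megatron; the internals of mpu.make_viewless_tensor are not shown in this diff, so torch.Tensor.clone() stands in for it:

    # Minimal illustration (not from the commit) of "viewless": a tensor
    # whose '._base' is None, i.e. one that owns its storage instead of
    # aliasing another tensor's buffer.
    import torch

    t = torch.zeros(4, 4)

    v = t.transpose(0, 1)    # a view: shares storage with 't'
    assert v._base is t      # freeing or overwriting 'v.data' would hit 't' too

    w = v.clone()            # stand-in for mpu.make_viewless_tensor(...)
    assert w._base is None   # 'w' owns its storage, safe to release on its own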
megatron/p2p_communication.py

@@ -136,22 +136,35 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
...
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # >>>
    def make_viewless_tensor(t):
        return mpu.make_viewless_tensor(t, requires_grad = True, keep_graph = False)
    # <<<

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
                                                        requires_grad = True,
                                                        keep_graph = False)
            # >>>
            # tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
            #                                             requires_grad = True,
            #                                             keep_graph = False)
            # +++
            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
            # <<<
        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()
            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
                                                        requires_grad = True,
                                                        keep_graph = False)
            # >>>
            # tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
            #                                             requires_grad = True,
            #                                             keep_graph = False)
            # +++
            tensor_recv_next = make_viewless_tensor(tensor_recv_next)
            # <<<
    return tensor_recv_prev, tensor_recv_next
...
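Why the received tensors need this treatment: gather_split_1d_tensor(...) returns a flat buffer, and .view(tensor_shape) aliases that buffer, so the result is exactly the kind of view the commit is trying to eliminate. A hedged sketch of the pattern; the local make_viewless_tensor helper above merely pins requires_grad=True, keep_graph=False, and detach().clone() stands in for mpu's helper, whose implementation this diff does not show:

    import torch

    def make_viewless(t):
        # Stand-in for mpu.make_viewless_tensor(t, requires_grad=True,
        # keep_graph=False); assumed behavior: copy into owned storage
        # and mark the result as a fresh autograd leaf.
        return t.detach().clone().requires_grad_()

    flat = torch.arange(12.0)      # e.g. the gathered 1-D chunk buffer
    recv = flat.view(3, 4)         # '.view()' aliases 'flat' ...
    assert recv._base is flat
    recv = make_viewless(recv)     # ... so replace it with an owning copy
    assert recv._base is None and recv.requires_grad

Hoisting the repeated three-line call into a local helper also keeps the recv_prev and recv_next branches from drifting apart.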
megatron/schedules.py

@@ -334,6 +334,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
...
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        # >>>
        if id(input_tensor) == id(output_tensor):
            raise Exception("tp %d, pp %d, vp %d." % (
                mpu.get_tensor_model_parallel_rank(),
                mpu.get_pipeline_model_parallel_rank(),
                mpu.get_virtual_pipeline_model_parallel_rank(),
            ))
        # <<<

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
...
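Per the commit message, the aliasing this check probes for shows up on pipeline rank 0 with virtual pipeline rank >= 1. A hypothetical sketch of the suspected failure mode, assuming a model chunk whose forward has no layers to run and therefore returns its input object unchanged:

    import torch

    def chunk_forward(x):
        # A degenerate virtual-pipeline stage with zero layers: nothing
        # to compute, so the input is returned as-is.
        return x

    input_tensor = torch.ones(2, 2)
    output_tensor = chunk_forward(input_tensor)

    # The commit's probe: input and output are literally the same object,
    # so releasing the output's buffer after sending it downstream would
    # also destroy the saved input.
    assert id(input_tensor) == id(output_tensor)

    output_tensor = output_tensor.clone()   # one way to break the alias
    assert id(input_tensor) != id(output_tensor)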
megatron/training.py

@@ -369,8 +369,18 @@ def setup_model_and_optimizer(model_provider_func, model_type):
...
    model = get_model(model_provider_func, model_type)

    # >>>
    # from lutil import pax
    # pax({"model": model})
    # if mpu.get_tensor_model_parallel_rank() == 0:
    #     from lutil import pax
    #     pax({
    #         # "model" : model,
    #         "model" : [
    #             sum(t.nelement() for t in m.parameters())
    #             for m in model
    #         ],
    #     })
    # else:
    #     torch.distributed.barrier()
    # exit(0)
    # <<<

    unwrapped_model = unwrap_model(model,
...
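The commented block above uses the author's lutil.pax helper to dump per-chunk parameter counts; with interleaved pipelining, get_model() returns a list of model chunks. A dependency-free sketch of the same measurement (function name and toy chunks here are illustrative, not from the diff):

    import torch

    def param_counts(model_chunks):
        """Parameter elements per virtual-pipeline model chunk."""
        return [sum(p.nelement() for p in chunk.parameters())
                for chunk in model_chunks]

    # Toy stand-ins for two model chunks:
    chunks = [torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)]
    print(param_counts(chunks))   # -> [20, 10]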