Commit 5bc9f889 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

Narrowed the issue to pipeline rank 0 with virtual pipeline rank >= 1.

parent f17a3933
......@@ -698,6 +698,18 @@ class ParallelTransformer(MegatronModule):
# See set_input_tensor()
hidden_states = self.input_tensor
# >>>
# if not self.pre_process and self.num_layers == 0:
# # raise Exception("tp %d, pp %d, vp %d ... hidden states %s, input tensor %s." % (
# # mpu.get_tensor_model_parallel_rank(),
# # mpu.get_pipeline_model_parallel_rank(),
# # mpu.get_virtual_pipeline_model_parallel_rank(),
# # "--" if hidden_states is None else str(hidden_states.shape),
# # "--" if self.input_tensor is None else str(self.input_tensor.shape),
# # ))
# hidden_states = hidden_states.clone()
# <<<
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
......
......@@ -136,22 +136,35 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# >>>
def make_viewless_tensor(t):
return mpu.make_viewless_tensor(t, requires_grad=True, keep_graph=False)
# <<<
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
# >>>
# tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
# requires_grad = True,
# keep_graph = False)
# +++
tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
# <<<
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
# >>>
# tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
# requires_grad = True,
# keep_graph = False)
# +++
tensor_recv_next = make_viewless_tensor(tensor_recv_next)
# <<<
return tensor_recv_prev, tensor_recv_next
......
......@@ -334,6 +334,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
# >>>
if id(input_tensor) == id(output_tensor):
raise Exception("tp %d, pp %d, vp %d." % (
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
mpu.get_virtual_pipeline_model_parallel_rank(),
))
# <<<
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
......
......@@ -369,8 +369,18 @@ def setup_model_and_optimizer(model_provider_func, model_type):
model = get_model(model_provider_func, model_type)
# >>>
# from lutil import pax
# pax({"model": model})
# if mpu.get_tensor_model_parallel_rank() == 0:
# from lutil import pax
# pax({
# # "model" : model,
# "model" : [
# sum(t.nelement() for t in m.parameters())
# for m in model
# ],
# })
# else:
# torch.distributed.barrier()
# exit(0)
# <<<
unwrapped_model = unwrap_model(model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment