Commit 6dcd0fb8 authored by dongcl

Modify schedule_chunk_1f1b to fix the NCCL communication bug

parent 124accba
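For context, the fix serializes the forward token-dispatch communication against the overlapped MLP backward by recording a CUDA event on one stream and waiting on it from the other. Below is a minimal, self-contained sketch of that record/wait pattern; the stream and variable names are illustrative, not the repo's API.

    import torch

    # Minimal sketch of cross-stream ordering with a CUDA event (assumes a CUDA device).
    # comp_stream stands in for the compute stream (e.g. the MLP backward),
    # comm_stream for the communication stream (e.g. the dispatch all-to-all).
    comp_stream = torch.cuda.current_stream()
    comm_stream = torch.cuda.Stream(device="cuda")
    sync_event = torch.cuda.Event()

    x = torch.randn(1024, 1024, device="cuda")

    comm_stream.wait_stream(comp_stream)   # make sure x is ready before the comm stream reads it
    with torch.cuda.stream(comm_stream):
        y = x * 2.0                        # stand-in for the forward dispatch (NCCL all-to-all)
        sync_event.record(comm_stream)     # mark the point that later work must not overtake

    sync_event.wait(comp_stream)           # compute stream blocks until the event is reached
    z = x @ x                              # stand-in for the overlapped MLP backward
    torch.cuda.synchronize()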
@@ -342,8 +342,6 @@ class MoeMlPNode(TransformerLayerNode):
         assert mlp_bias is None
         # pre_mlp_layernorm_output used
-        # cur_stream = torch.cuda.current_stream()
-        # self.common_state.pre_mlp_layernorm_output.record_stream(cur_stream)
         self.common_state.pre_mlp_layernorm_output = None
         return expert_output, shared_expert_output
@@ -542,6 +540,8 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
     def state(self):
         return self._model_chunk_state
 
+# F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
+F_DISPATCH_B_MLP_SYNC_EVENT = None
 
 def schedule_layer_1f1b(
     f_layer,
@@ -581,13 +581,17 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.attn.forward(f_input)
 
+    f_dispatch_b_mlp_sync_event = None
+    if f_layer is not None and b_layer is not None:
+        f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
+
     if f_layer is not None:
         with f_context:
-            f_input = f_layer.dispatch.forward(f_input)
+            f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
 
     if b_layer is not None:
         with b_context:
-            b_grad = b_layer.mlp.backward(b_grad)
+            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()
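Note that this diff leaves F_DISPATCH_B_MLP_SYNC_EVENT as None at module level (the torch.cuda.Event() construction is commented out), and schedule_layer_1f1b only forwards it when both a forward and a backward layer are scheduled together. Presumably the event is created elsewhere before overlapped 1F1B runs; a hypothetical lazy initializer, not part of this commit, could look like the sketch below.

    # Hypothetical helper (not in this commit): create the shared sync event on
    # first use so it only exists when forward/backward overlap is active.
    def get_f_dispatch_b_mlp_sync_event():
        global F_DISPATCH_B_MLP_SYNC_EVENT
        if F_DISPATCH_B_MLP_SYNC_EVENT is None:
            F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
        return F_DISPATCH_B_MLP_SYNC_EVENT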
@@ -690,32 +694,28 @@ def schedule_chunk_1f1b(
         )
         torch.cuda.nvtx.range_pop()
 
-    # tail forward
-    f_input = layer_pre_forward()
-    del layer_pre_forward
     # tail backward
     grad = layer_pre_backward()
     del layer_pre_backward
     with b_context:
         for i in range(overlaped_layers, b_num_layers):
             b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
             torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
-            tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
+            _, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
             torch.cuda.nvtx.range_pop()
-        if b_schedule_plan is not None:
-            b_schedule_plan.pre_process.backward(grad)
 
+    # tail forward
+    f_input = layer_pre_forward()
+    del layer_pre_forward
     with f_context:
         for i in range(overlaped_layers, f_num_layers):
             f_layer = f_schedule_plan.get_layer(i)
             torch.cuda.nvtx.range_push(f"layer_{i}f")
-            f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
+            f_input, _, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
             torch.cuda.nvtx.range_pop()
-        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
-            f_input = f_schedule_plan.post_process.forward(f_input)
 
     # output pp send receive, overlapped with attn backward
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
@@ -732,6 +732,13 @@ def schedule_chunk_1f1b(
     layer_pre_backward_dw()
     del layer_pre_backward_dw
 
+    with f_context:
+        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
+            f_input = f_schedule_plan.post_process.forward(f_input)
+
+    with b_context:
+        if b_schedule_plan is not None:
+            b_schedule_plan.pre_process.backward(grad)
 
     if f_schedule_plan:
         f_schedule_plan.wait_current_stream()
     if b_schedule_plan:
@@ -73,17 +73,20 @@ class ScheduleNode:
         )
         return output_grad
 
-    def forward(self, inputs=()):
+    def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
         """schedule node forward"""
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
-        return self._forward(*inputs)
+        return self._forward(*inputs, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
 
-    def _forward(self, *inputs):
+    def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} forward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
                 self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs]
                 for i, input in enumerate(self.inputs):
                     if input is not None:
@@ -98,6 +101,10 @@ class ScheduleNode:
                 data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
                 self.output = data
 
+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
+
             torch.cuda.nvtx.range_pop()
 
         if self.free_inputs:
@@ -111,16 +118,19 @@ class ScheduleNode:
         """get the forward output"""
         return self.output
 
-    def backward(self, output_grad):
+    def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
         """schedule node backward"""
        if not isinstance(output_grad, tuple):
             output_grad = (output_grad,)
-        return self._backward(*output_grad)
+        return self._backward(*output_grad, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
 
-    def _backward(self, *output_grad):
+    def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} backward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
                 outputs = self.output
                 if not isinstance(outputs, tuple):
                     outputs = (outputs,)
@@ -131,6 +141,10 @@ class ScheduleNode:
                     output_grad = self.backward_func(outputs, output_grad)
                 else:
                     output_grad = self.default_backward_func(outputs, output_grad)
+
+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
+
             torch.cuda.nvtx.range_pop()
 
         # output_grad maybe from another stream
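With these hooks in place, a caller can order one node's forward communication against another node's backward compute without synchronizing whole streams. A usage sketch, assuming dispatch_node and mlp_node are ScheduleNode instances pinned to the communication and compute streams respectively (variable names are illustrative):

    # The dispatch forward records the event on its stream after its work is enqueued;
    # the MLP backward waits on that event before starting on its own stream.
    sync_event = torch.cuda.Event()
    f_out = dispatch_node.forward(f_in, stream_record_event=sync_event)
    b_grad = mlp_node.backward(b_grad_out, stream_wait_event=sync_event)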
@@ -198,17 +212,6 @@ def schedule_chunk_1f1b(
     )
 
-def schedule_chunk_forward(schedule_plan):
-    """model level fine-grained forward schedule"""
-    f_input = schedule_chunk_1f1b(schedule_plan, None, None)
-    return f_input
-
-def schedule_chunk_backward(schedule_plan, grad):
-    """model level fine-grained backward schedule"""
-    tmp = schedule_chunk_1f1b(None, schedule_plan, grad)
-
 _COMP_STREAM = None
 _COM_STREAM = None
@@ -221,7 +224,7 @@ def set_streams(comp_stream=None, com_stream=None):
         return
     if comp_stream is None:
-        comp_stream = torch.cuda.Stream(device="cuda")
+        comp_stream = torch.cuda.current_stream()
     if com_stream is None:
         com_stream = torch.cuda.Stream(device="cuda")
@@ -448,6 +448,8 @@ def forward_backward_pipelining_with_interleaving(
         """Helper method to run backward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
+        nonlocal output_tensor_grads
+
         model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
         parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
@@ -40,6 +40,8 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
 
+    # split_bw: bool = False
+
 @dataclass
 class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
@@ -196,6 +196,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                 repeat=1),
             on_trace_ready=trace_handler,
             record_shapes=True,
+            with_stack=True,
         )
         prof.start()
     elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler: