Unverified Commit 86d22581 authored by Bin Jia's avatar Bin Jia Committed by GitHub
Browse files

[shardformer] Add overlap optional for HybridParallelPlugin (#4615)

* add optional overlap for plugin

* remove fixed todo
parent a39a5c66
......@@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
......@@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism)
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
......
......@@ -180,7 +180,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
overlap = ctx.overlap
if not overlap:
# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
......@@ -191,7 +190,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment