"tools/cfgs/vscode:/vscode.git/clone" did not exist on "70857b83ca8689732af9c1e1efc84c99041ac467"
Unverified Commit 86d22581 authored by Bin Jia, committed by GitHub

[shardformer] Add overlap optional for HybridParallelPlugin (#4615)

* add optional overlap for plugin

* remove fixed todo
parent a39a5c66
@@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        enable_sequence_overlap: bool = False,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
-            enable_sequence_parallelism=enable_sequence_parallelism)
+            enable_sequence_parallelism=enable_sequence_parallelism,
+            enable_sequence_overlap=enable_sequence_overlap)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
...
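For context, a hedged usage sketch of the new option. The argument values below are illustrative, not taken from the commit, and `enable_sequence_overlap` is presumably meaningful only alongside `enable_sequence_parallelism`:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative values; tp_size/pp_size and the surrounding setup are assumptions.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    enable_sequence_overlap=True,  # flag added by this commit (defaults to False)
)
booster = Booster(plugin=plugin)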
@@ -180,7 +180,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         overlap = ctx.overlap
         if not overlap:
-            # TODO: overlap SP input with gradient computation
             input_parallel = _gather(input_, dim, process_group)
             total_input = input_parallel
@@ -191,7 +190,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         grad_output = grad_output.view(-1, grad_output.shape[-1])
         total_input = total_input.view(-1, total_input.shape[-1])
-        # TODO: overlap SP input with gradient computation
         if ctx.async_grad_reduce_scatter:
             # Asynchronous reduce-scatter
             input_list = [
...
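The removed TODOs refer to overlapping the sequence-parallel input all-gather with gradient computation, which the `overlap` flag now controls. A minimal sketch of that pattern, assuming a `torch.distributed` process group; the names (`backward_with_overlap`, `input_shard`) are illustrative, not the repo's actual code:

import torch
import torch.distributed as dist

def backward_with_overlap(grad_output, input_shard, weight, group):
    """Sketch: overlap the input all-gather with grad_input computation."""
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(input_shard) for _ in range(world_size)]
    # Launch the all-gather asynchronously instead of blocking on it.
    handle = dist.all_gather(gathered, input_shard, group=group, async_op=True)
    # grad_input needs only grad_output and weight, so it can proceed
    # while the gather is still in flight.
    grad_input = grad_output.matmul(weight)
    handle.wait()  # the full input is required for the weight gradient
    total_input = torch.cat(gathered, dim=0)
    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
        total_input.reshape(-1, total_input.shape[-1]))
    return grad_input, grad_weight

Whether this actually hides the communication depends on the relative cost of the collective and the matmul, which is presumably why the commit exposes it as an opt-in flag rather than enabling it unconditionally.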