Unverified commit 3c19f106, authored by Eric Harper, committed by GitHub

Sequence parallel perf updates (#1437)



* use _all_gather_base
Signed-off-by: ericharper <complex451@gmail.com>

* use _reduce_scatter_base
Signed-off-by: ericharper <complex451@gmail.com>

* remove torch empty in backward
Signed-off-by: ericharper <complex451@gmail.com>

* check self.attn_mask_type
Signed-off-by: ericharper <complex451@gmail.com>

* remove extra arg
Signed-off-by: ericharper <complex451@gmail.com>

* update get_tensor_shapes logic
Signed-off-by: ericharper <complex451@gmail.com>
parent 2e025ab5
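
The thread running through the hunks below is replacing `torch.distributed.all_gather` / `reduce_scatter` calls that operate on a Python list of chunk views with the single-tensor `_all_gather_base` / `_reduce_scatter_base` variants. A minimal before/after sketch of the gather side of that pattern, assuming an already-initialized process group (the helper names are illustrative, not from this commit):

```python
import torch
import torch.distributed as dist

def gather_first_dim_old(inp: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Old pattern: gather into a Python list of views over one preallocated buffer.
    out = torch.empty((inp.shape[0] * world_size, *inp.shape[1:]),
                      dtype=inp.dtype, device=inp.device)
    dist.all_gather(list(out.chunk(world_size)), inp.contiguous(), group=group)
    return out

def gather_first_dim_new(inp: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # New pattern: gather straight into the contiguous buffer, no chunk list.
    out = torch.empty((inp.shape[0] * world_size, *inp.shape[1:]),
                      dtype=inp.dtype, device=inp.device)
    dist._all_gather_base(out, inp.contiguous(), group=group)
    return out
```

Both helpers fill the same contiguous buffer; the second skips building the list of views and hands the backend a single flat output tensor.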
@@ -162,7 +162,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         if (
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
-            and mask is not None  # mask tensor must not be None
+            and (
+                self.attn_mask_type == AttnMaskType.causal
+                or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
+            )
             and 16 < sk <= 2048  # sk must be 16 ~ 2048
             and sq % 4 == 0  # sq must be divisor of 4
             and sk % 4 == 0  # sk must be divisor of 4
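
With the old `mask is not None` guard, causal attention (which passes no explicit mask tensor) always fell back to the unfused softmax; the new guard keys off `self.attn_mask_type` instead. A self-contained sketch of just the clauses visible in this hunk, with a stand-in for apex's `AttnMaskType` enum (the real method checks further shape constraints not shown here):

```python
import enum

class AttnMaskType(enum.Enum):  # stand-in for apex.transformer.enums.AttnMaskType
    padding = 1
    causal = 2

def fused_path_allowed(fusion_enabled, input_in_float16, attn_mask_type, mask, sq, sk):
    return (
        fusion_enabled                      # user wants to fuse
        and input_in_float16                # input must be fp16
        and (
            attn_mask_type == AttnMaskType.causal
            or (attn_mask_type == AttnMaskType.padding and mask is not None)
        )
        and 16 < sk <= 2048                 # sk must be 16 ~ 2048
        and sq % 4 == 0
        and sk % 4 == 0
    )

# Causal attention no longer needs an explicit mask tensor to take the fused path;
# padding-style masking still does.
assert fused_path_allowed(True, True, AttnMaskType.causal, None, sq=128, sk=128)
assert not fused_path_allowed(True, True, AttnMaskType.padding, None, sq=128, sk=128)
```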
@@ -55,28 +55,31 @@ def get_tensor_shapes(
     assert (
         len(tensor_shape) == 3
     ), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
     sequence_length, micro_batch_size, hidden_size = tensor_shape
-    seq_len = sequence_length
-    if sequence_parallel_enabled:
-        seq_len = sequence_length // parallel_state.get_tensor_model_parallel_world_size()
     tensor_shapes = []
+    if sequence_parallel_enabled:
+        seq_length = sequence_length // parallel_state.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = sequence_length
     if model_type == ModelType.encoder_and_decoder:
-        if decoder_sequence_length is None:
-            raise ValueError("`decoder_sequence_length` is required for `ModelType.encoder_and_decoder`")
-        dec_seq_len = decoder_sequence_length
         if sequence_parallel_enabled:
-            dec_seq_len = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size()
+            dec_seq_length = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size()
+        else:
+            dec_seq_length = decoder_sequence_length
         if parallel_state.is_pipeline_stage_before_split(rank):
-            # If next rank is after split, then need transpose for encoder_hidden_state.
-            if parallel_state.is_pipeline_stage_before_split(rank + 1):
-                tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
-            else:
-                tensor_shapes.append((dec_seq_len, micro_batch_size, hidden_size))
+            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
         else:
-            tensor_shapes.append((dec_seq_len, micro_batch_size, hidden_size))
-            tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
+            tensor_shapes.append((dec_seq_length, micro_batch_size, hidden_size))
+            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
     else:
-        tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
+        tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
     return tensor_shapes
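
Under sequence parallelism each tensor-parallel rank holds only `sequence_length // tp_world_size` tokens, so the pipeline send/recv buffers shrink by the same factor; the rewrite above also applies the split symmetrically to the decoder sequence length. A toy version of the shape computation with the `parallel_state` lookup replaced by an explicit argument (hypothetical helper, for illustration only):

```python
def p2p_tensor_shape(sequence_length, micro_batch_size, hidden_size,
                     tp_world_size, sequence_parallel_enabled):
    # Mirrors the seq_length computation in get_tensor_shapes above.
    if sequence_parallel_enabled:
        seq_length = sequence_length // tp_world_size
    else:
        seq_length = sequence_length
    return (seq_length, micro_batch_size, hidden_size)

# e.g. 2048-token sequences, micro-batch 4, hidden 1024, tensor-parallel size 8:
assert p2p_tensor_shape(2048, 4, 1024, 8, True) == (256, 4, 1024)
assert p2p_tensor_shape(2048, 4, 1024, 8, False) == (2048, 4, 1024)
```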
@@ -302,11 +302,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
                 device=torch.cuda.current_device(),
                 requires_grad=False,
             )
-            torch.distributed.all_gather(
-                list(all_gather_buffer.chunk(world_size)),
-                input,
-                group=get_tensor_model_parallel_group(),
-            )
+            torch.distributed._all_gather_base(all_gather_buffer, input, group=get_tensor_model_parallel_group())
             total_input = all_gather_buffer
         else:
             total_input = input
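
This is the forward-path gather of the sequence-parallel linear layer: activations sharded along the sequence (first) dimension are collected into a buffer `world_size` times larger before the GEMM. A condensed sketch, assuming an initialized process group (names are illustrative, not the layer's actual code):

```python
import torch
import torch.distributed as dist

def sequence_parallel_linear_forward(inp, weight, bias, group, world_size):
    # inp: [s / world_size, b, h_in]  ->  total_input: [s, b, h_in]
    total_input = torch.empty((inp.shape[0] * world_size, *inp.shape[1:]),
                              dtype=inp.dtype, device=inp.device)
    dist._all_gather_base(total_input, inp.contiguous(), group=group)
    # Usual linear over the full sequence: [s, b, h_out]
    return torch.nn.functional.linear(total_input, weight, bias)
```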
@@ -331,15 +327,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
                 device=torch.cuda.current_device(),
                 requires_grad=False,
             )
-            handle = torch.distributed.all_gather(
-                list(all_gather_buffer.chunk(get_tensor_model_parallel_world_size())),
+            handle = torch.distributed._all_gather_base(
+                all_gather_buffer,
                 input,
                 group=get_tensor_model_parallel_group(),
                 async_op=True,
             )
-            # Delay the start of input gradient computation shortly (3us) to have gather scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
             total_input = all_gather_buffer
         else:
             total_input = input
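
In backward the same gather is launched with `async_op=True` so it can overlap with the input-gradient GEMM, and this commit drops the dummy `torch.empty(1) + 1` kernel that was previously inserted to nudge scheduling. The general launch/compute/wait pattern looks roughly like this (sketch; assumes an initialized process group and CUDA tensors):

```python
import torch
import torch.distributed as dist

def backward_with_overlapped_gather(inp, grad_output, weight, group, world_size):
    # inp: [s / world_size, b, h_in], grad_output: [s, b, h_out], weight: [h_out, h_in]
    gathered = torch.empty((inp.shape[0] * world_size, *inp.shape[1:]),
                           dtype=inp.dtype, device=inp.device)
    handle = dist._all_gather_base(gathered, inp, group=group, async_op=True)

    # grad_input does not depend on the gathered activations, so it can run
    # while the collective is in flight on the communication stream.
    grad_input = grad_output.matmul(weight)

    handle.wait()  # `gathered` is only safe to read after the wait
    return grad_input, gathered
```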
@@ -358,21 +351,16 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             handle = torch.distributed.all_reduce(
                 grad_input, group=get_tensor_model_parallel_group(), async_op=True
             )
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # all-reduce scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
         if ctx.sequence_parallel_enabled:
             assert not ctx.async_grad_allreduce
             sub_grad_input = torch.empty(input.shape, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False)
-            handle = torch.distributed.reduce_scatter(
+            handle = torch.distributed._reduce_scatter_base(
                 sub_grad_input,
-                list(grad_input.chunk(get_tensor_model_parallel_world_size())),
+                grad_input,
                 group=get_tensor_model_parallel_group(),
-                async_op=True,
+                async_op=True
             )
-            # Delay the start of weight gradient computation shortly (3us) to have reduce scatter scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
         if ctx.gradient_accumulation_fusion:
             if not ctx.use_16bit_in_wgrad_accum_fusion:
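
For the sequence-parallel path the input gradient is reduce-scattered back into per-rank sequence shards, again through the single-tensor primitive and with `async_op=True` so the collective can overlap with the weight-gradient computation that follows (as the removed delay comments describe). Shape-wise it is the inverse of the forward gather (sketch; assumes an initialized process group):

```python
import torch
import torch.distributed as dist

def reduce_scatter_grad_input(grad_input, group, world_size):
    # grad_input: [s, b, h]  ->  sub_grad_input: [s / world_size, b, h]
    assert grad_input.shape[0] % world_size == 0
    sub_grad_input = torch.empty((grad_input.shape[0] // world_size, *grad_input.shape[1:]),
                                 dtype=grad_input.dtype, device=grad_input.device)
    handle = dist._reduce_scatter_base(sub_grad_input, grad_input,
                                       group=group, async_op=True)
    # The layer above defers the wait so the collective overlaps with other work;
    # here we wait immediately for simplicity.
    handle.wait()
    return sub_grad_input
```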
@@ -103,15 +103,11 @@ def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
     shape[0] *= world_size
     output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
-    # Original implementation uses `_all_gather_base` as follows.
-    # Deliberately keep the comment-out for reference because
-    # I'd love to switch to this API once this gets public/stable.
-    # torch.distributed._all_gather_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
-    torch.distributed.all_gather(
-        list(output.chunk(world_size)),
+    torch.distributed._all_gather_base(
+        output,
         input_.contiguous(),
-        group=get_tensor_model_parallel_group(),
+        group=get_tensor_model_parallel_group()
     )
     return output
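
`_all_gather_base` and `_reduce_scatter_base` are still underscore-prefixed, private `torch.distributed` functions; the old comment about waiting for them to become public/stable goes away along with the fallback. Newer PyTorch releases expose the same operations under public names, so a version-tolerant wrapper could look like this (illustrative, not part of the commit):

```python
import torch.distributed as dist

def all_gather_into(output, inp, group=None, async_op=False):
    # Prefer the public name (PyTorch >= 1.13), fall back to the private one.
    fn = getattr(dist, "all_gather_into_tensor", None) or dist._all_gather_base
    return fn(output, inp, group=group, async_op=async_op)

def reduce_scatter_into(output, inp, group=None, async_op=False):
    fn = getattr(dist, "reduce_scatter_tensor", None) or dist._reduce_scatter_base
    return fn(output, inp, group=group, async_op=async_op)
```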
@@ -126,15 +122,11 @@ def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
     assert shape[0] % world_size == 0
     shape[0] //= world_size
     output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
-    # Original implementation uses `_reduce_scatter_base` as follows.
-    # Deliberately keep the comment-out for reference because
-    # I'd love to switch to this API once this gets public/stable.
-    # torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
-    torch.distributed.reduce_scatter(
+    torch.distributed._reduce_scatter_base(
         output,
-        list(input_.contiguous().chunk(world_size)),
-        group=get_tensor_model_parallel_group(),
+        input_.contiguous(),
+        group=get_tensor_model_parallel_group()
     )
     return output
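
`_reduce_scatter_base` sums the full-size input across ranks and leaves each rank with its own slice along dim 0. A process-group-free reference implementation of that semantic, the kind of thing a unit test for the distributed path might compare against (illustrative):

```python
import torch

def reduce_scatter_base_reference(per_rank_inputs, rank):
    # Sum the full-size tensors from every rank ("reduce"), then keep this
    # rank's slice along dim 0 ("scatter").
    world_size = len(per_rank_inputs)
    reduced = torch.stack(per_rank_inputs).sum(dim=0)
    return reduced.chunk(world_size, dim=0)[rank].clone()

inputs = [torch.full((4, 2), float(r)) for r in range(4)]   # pretend per-rank inputs
expected = torch.full((1, 2), 0.0 + 1.0 + 2.0 + 3.0)
assert torch.equal(reduce_scatter_base_reference(inputs, rank=1), expected)
```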
@@ -40,8 +40,9 @@ def gather_split_1d_tensor(tensor):
         device=torch.cuda.current_device(),
         requires_grad=False,
     )
-    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
-    torch.distributed.all_gather(
-        chunks, tensor, group=parallel_state.get_tensor_model_parallel_group()
-    )
+    torch.distributed._all_gather_base(
+        gathered,
+        tensor,
+        group=parallel_state.get_tensor_model_parallel_group()
+    )
     return gathered
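
`gather_split_1d_tensor` operates on flat 1-D shards, so the old code had to build explicit slice views `gathered[i * numel : (i + 1) * numel]`; the base variant fills the same flat buffer directly. The layout equivalence can be checked locally without a process group:

```python
import torch

numel, world_size = 5, 3
shards = [torch.arange(numel) + 100 * r for r in range(world_size)]  # pretend per-rank shards

# What all_gather into the slice views produced ...
gathered = torch.empty(world_size * numel, dtype=torch.long)
for i, shard in enumerate(shards):
    gathered[i * numel:(i + 1) * numel] = shard

# ... is exactly the contiguous concatenation that _all_gather_base writes.
assert torch.equal(gathered, torch.cat(shards))
```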