Unverified commit 50188092, authored by Cheng Wan and committed by GitHub

[DP] fix: engine crash when decode batch is padded (#8995)

parent 326a901d
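This commit touches two areas: the DP-attention communication helpers (the first three hunks) and the `ForwardBatch` padding/restore logic (the last three). The common thread is that a padded decode batch could leave `hidden_states` or `residual` aliasing the shared `forward_batch.gathered_buffer`, so a later collective would overwrite its own input. Below is a minimal sketch (not sglang code, plain torch only) of that aliasing hazard:

```python
# Minimal sketch of the aliasing hazard this commit removes: if a tensor
# handed to a collective is a view of the destination buffer, the gather
# overwrites its own input. Fresh allocations avoid that.
import torch

gathered_buffer = torch.zeros(8, 4)  # stand-in for forward_batch.gathered_buffer
hidden_states = gathered_buffer[:4]  # a view: same storage, no copy

assert hidden_states.data_ptr() == gathered_buffer.data_ptr()  # aliased

# Writing into the buffer silently changes the "input" as well:
gathered_buffer.fill_(1.0)
print(hidden_states.sum())           # tensor(16.), not the original 0.

# The fix: allocate a destination that cannot alias the source.
local = torch.empty_like(gathered_buffer[:4])
local.copy_(hidden_states)
assert local.data_ptr() != gathered_buffer.data_ptr()
```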
@@ -408,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                forward_batch.gathered_buffer[
-                    : forward_batch.input_ids.shape[0]
-                ].clone(),
+                torch.empty_like(
+                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
+                ),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
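In this hunk the all-gather destination becomes `torch.empty_like(...)` over the buffer slice instead of `.clone()` of it. Both produce fresh, non-aliasing storage, but `clone()` also copies contents that `attn_tp_all_gather_into_tensor` is about to overwrite anyway. A hedged sketch of the difference, with a plain `copy_` standing in for the collective:

```python
# Hedged sketch: both calls give a non-aliasing tensor of the right shape
# and dtype, but clone() also copies the about-to-be-discarded contents.
import torch

buf = torch.randn(8, 4)
dst_clone = buf[:4].clone()            # allocates AND copies 4x4 floats
dst_empty = torch.empty_like(buf[:4])  # allocates only; contents are garbage

# Either destination is safe for an in-place gather that overwrites it
# fully; the copy_ below stands in for the collective's write.
dst_empty.copy_(torch.ones(4, 4))
assert dst_empty.data_ptr() != buf.data_ptr()
```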
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual.copy_(hidden_states)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer,
+            torch.empty_like(forward_batch.gathered_buffer),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
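The rewritten branch rebinds `residual = hidden_states` instead of doing an in-place `residual.copy_(hidden_states)`, and it gathers into a fresh `torch.empty_like(forward_batch.gathered_buffer)` rather than the shared buffer itself. One plausible reading of the crash (an assumption, not confirmed by the diff): after decode padding the row counts no longer match, and `copy_` is strict about shapes while rebinding never touches the old storage:

```python
# Hedged sketch of the decode-padding failure mode the rewrite avoids:
# copy_ requires broadcast-compatible shapes; rebinding always succeeds.
import torch

residual = torch.zeros(4, 8)       # sized for the unpadded batch
hidden_states = torch.randn(6, 8)  # padded decode batch: more rows

try:
    residual.copy_(hidden_states)  # in-place copy: RuntimeError, shape mismatch
except RuntimeError as e:
    print("copy_ failed:", e)

residual = hidden_states           # rebinding: always safe, zero copies
assert residual.shape == (6, 8)
```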
@@ -552,10 +550,6 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        if hidden_states.data_ptr() is global_hidden_states.data_ptr():
-            hidden_states = torch.empty_like(hidden_states)
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
             # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
             dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
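With the gather destinations above now freshly allocated, `hidden_states` can no longer share storage with `global_hidden_states`, so the defensive `data_ptr()` comparison (and the extra allocation it triggered) is deleted as dead code. A sketch of what that removed guard was doing, assuming plain torch (note the removed line compared pointers with `is`; `==` is the safe comparison for the integers `data_ptr()` returns):

```python
# Hedged sketch of the removed guard: detect aliasing via data_ptr() and
# break it with a fresh allocation. With non-aliasing destinations
# allocated upstream, this check never fires.
import torch

buf = torch.randn(8, 4)
global_hidden_states = buf
hidden_states = buf[:4]            # a view of the same storage

if hidden_states.data_ptr() == global_hidden_states.data_ptr():
    hidden_states = torch.empty_like(hidden_states)  # break the alias

assert hidden_states.data_ptr() != global_hidden_states.data_ptr()
```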
@@ -653,12 +653,30 @@ class ForwardBatch:
         else:
             num_tokens = global_num_tokens[0]
-        if self.forward_mode.is_decode():
-            setattr(self, "raw_bs", self.batch_size)
-            self.batch_size = num_tokens
         bs = self.batch_size
+        if self.forward_mode.is_decode():
+            if self.is_extend_in_batch and dp_padding_mode.is_max_len():
+                setattr(self, "_original_forward_mode", self.forward_mode)
+                self.forward_mode = ForwardMode.EXTEND
+                self.extend_num_tokens = bs
+                self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
+                self.extend_prefix_lens = self.seq_lens - 1
+                self.extend_start_loc = torch.arange(
+                    bs, dtype=torch.int32, device=self.seq_lens.device
+                )
+                self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
+                self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
+                self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
+            else:
+                setattr(self, "_original_batch_size", self.batch_size)
+                if self.spec_info is not None:
+                    bs = self.batch_size = (
+                        num_tokens // self.spec_info.num_tokens_per_batch
+                    )
+                else:
+                    bs = self.batch_size = num_tokens
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
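The new decode branch relabels a max-len-padded decode batch as `ForwardMode.EXTEND` when other DP ranks are extending: a decode step appends exactly one token per request, so the extend metadata is fully determined by `seq_lens`. A small worked example of those formulas:

```python
# Worked example of the decode->extend metadata built above, for a toy
# batch: each request extends by exactly one token.
import torch

seq_lens = torch.tensor([5, 9, 3], dtype=torch.int32)  # lengths incl. the new token
bs = seq_lens.numel()

extend_seq_lens = torch.full_like(seq_lens, 1)          # one new token per request
extend_prefix_lens = seq_lens - 1                       # the rest is cached prefix
extend_start_loc = torch.arange(bs, dtype=torch.int32)  # token i belongs to request i

print(extend_seq_lens.tolist())     # [1, 1, 1]
print(extend_prefix_lens.tolist())  # [4, 8, 2]
print(extend_start_loc.tolist())    # [0, 1, 2]
```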
@@ -689,6 +707,7 @@ class ForwardBatch:
         if self.mrope_positions is not None:
             self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
         # TODO: check if we need to pad other tensors
-        self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
+        if self.extend_seq_lens is not None:
+            self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
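`_pad_tensor_to_size` is called but not shown in this diff; the `is not None` guard added here matters because a pure decode batch has `extend_seq_lens = None` until the conversion above creates it. A hypothetical sketch of what such a pad helper could look like (name and behavior are assumptions, not the real implementation):

```python
# Hedged sketch: one plausible implementation of a pad-to-size helper like
# the _pad_tensor_to_size calls above (the real body is not in this diff).
import torch

def pad_tensor_to_size(t: torch.Tensor, size: int) -> torch.Tensor:
    """Right-pad dim 0 of `t` with zeros up to `size`; no-op if already long enough."""
    if t.shape[0] >= size:
        return t
    pad = torch.zeros(size - t.shape[0], *t.shape[1:], dtype=t.dtype, device=t.device)
    return torch.cat([t, pad], dim=0)

x = torch.tensor([1, 2, 3])
print(pad_tensor_to_size(x, 5).tolist())  # [1, 2, 3, 0, 0]
```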
@@ -712,7 +731,9 @@ class ForwardBatch:
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
-        bs = getattr(self, "raw_bs", self.batch_size)
+        self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
+        self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
+        bs = self.batch_size
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
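`post_forward_mlp_sync_batch` now restores `forward_mode` and `batch_size` from the `_original_*` attributes that the padded paths stashed with `setattr`; `getattr` with a default makes the restore a no-op on ranks that never padded. A minimal sketch of that stash/restore pattern:

```python
# Hedged sketch of the stash/restore pattern used above: setattr saves the
# original value only on the padded path, and getattr's default makes the
# restore a no-op when nothing was stashed.
class Batch:
    def __init__(self):
        self.batch_size = 4

b = Batch()

# Padded path: remember the original before overwriting.
setattr(b, "_original_batch_size", b.batch_size)
b.batch_size = 16  # padded size

# Restore: falls back to the current value if the attribute was never set.
b.batch_size = getattr(b, "_original_batch_size", b.batch_size)
print(b.batch_size)  # 4
```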