Unverified Commit 0e0eef00 authored by Cheng Wan, committed by GitHub

[DP] fix the compatibility issue between DP attention and `--attention-backend triton` (#8723)

parent cb099d20
@@ -646,12 +646,17 @@ class ForwardBatch:
             device=model_runner.device,
         )
-        bs = self.batch_size
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
+        if self.forward_mode.is_decode():
+            setattr(self, "raw_bs", self.batch_size)
+            self.batch_size = num_tokens
+        bs = self.batch_size
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -659,6 +664,9 @@ class ForwardBatch:
         seq_len_fill_value = (
             model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+        self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+            bs - self.seq_lens.shape[0]
+        )
         self.seq_lens = self._pad_tensor_to_size(
             self.seq_lens, bs, value=seq_len_fill_value
         )
@@ -702,7 +710,7 @@ class ForwardBatch:
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
-        bs = self.batch_size
+        bs = getattr(self, "raw_bs", self.batch_size)
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
...
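For context, the patch pads the decode batch on each DP rank up to the rank's `num_tokens`, stashes the real batch size in `raw_bs` so it can be recovered after the forward pass, and bumps `seq_lens_sum` so it stays consistent with the padded `seq_lens`. The snippet below is a minimal, self-contained sketch of that bookkeeping, not the SGLang implementation: `pad_tensor_to_size` is a hypothetical stand-in for `ForwardBatch._pad_tensor_to_size`, and the numbers are illustrative only.

```python
import torch


def pad_tensor_to_size(t: torch.Tensor, size: int, value: int = 0) -> torch.Tensor:
    """Right-pad a 1-D tensor with `value` until it has `size` elements."""
    if t.shape[0] >= size:
        return t
    pad = torch.full((size - t.shape[0],), value, dtype=t.dtype, device=t.device)
    return torch.cat([t, pad])


# Example: a DP rank with 2 real decode requests padded to num_tokens = 4.
seq_lens = torch.tensor([7, 3])
raw_bs, num_tokens = seq_lens.shape[0], 4   # raw_bs is what the patch saves
seq_len_fill_value = 1                      # fill value reported by the attention backend

# Keep seq_lens_sum consistent with the padded seq_lens, mirroring the diff above.
seq_lens_sum = int(seq_lens.sum()) + seq_len_fill_value * (num_tokens - seq_lens.shape[0])
seq_lens = pad_tensor_to_size(seq_lens, num_tokens, value=seq_len_fill_value)

assert seq_lens.tolist() == [7, 3, 1, 1]
assert seq_lens_sum == 12

# After the forward pass, only the first raw_bs rows correspond to real requests,
# which is why post_forward_mlp_sync_batch falls back to raw_bs when it exists.
```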