Unverified Commit 0e0eef00 authored by Cheng Wan, committed by GitHub

[DP] fix the compatibility issue between DP attention and `--attention-backend triton` (#8723)

parent cb099d20
@@ -646,12 +646,17 @@ class ForwardBatch:
             device=model_runner.device,
         )
         bs = self.batch_size

         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
+
+        if self.forward_mode.is_decode():
+            setattr(self, "raw_bs", self.batch_size)
+            self.batch_size = num_tokens
+            bs = self.batch_size

         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
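
The padding above leans on a `_pad_tensor_to_size` helper that this diff only calls, never defines. A minimal sketch of the behavior the call sites imply, assuming the helper pads along dim 0 and is a no-op when the tensor is already long enough (the real implementation may differ):

```python
import torch

def _pad_tensor_to_size(x: torch.Tensor, size: int, value: int = 0) -> torch.Tensor:
    """Grow dim 0 of `x` to `size`, filling the new slots with `value`."""
    if x.shape[0] >= size:
        # Already large enough; the call sites above expect a no-op here.
        return x
    pad = torch.full(
        (size - x.shape[0], *x.shape[1:]), value, dtype=x.dtype, device=x.device
    )
    return torch.cat([x, pad], dim=0)
```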
@@ -659,6 +664,9 @@ class ForwardBatch:
         seq_len_fill_value = (
             model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+        self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+            bs - self.seq_lens.shape[0]
+        )
         self.seq_lens = self._pad_tensor_to_size(
             self.seq_lens, bs, value=seq_len_fill_value
         )
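
Without the new `seq_lens_sum` adjustment, the cached scalar would still count only the real requests while `seq_lens` gains padded entries holding the backend's CUDA-graph fill value, so any backend that sizes its work from `seq_lens_sum` would disagree with the padded tensor. A toy check with made-up numbers:

```python
import torch

# Toy numbers only: 3 real decode requests padded out to a batch of 5,
# with a CUDA-graph seq-len fill value of 1.
seq_lens = torch.tensor([7, 3, 9])
seq_lens_sum = int(seq_lens.sum())  # 19, over the real requests
bs, seq_len_fill_value = 5, 1

# Mirror the patch: bump the scalar for the padded slots first ...
seq_lens_sum = seq_lens_sum + seq_len_fill_value * (bs - seq_lens.shape[0])  # 21

# ... then pad the tensor with the same fill value.
pad = torch.full((bs - seq_lens.shape[0],), seq_len_fill_value, dtype=seq_lens.dtype)
seq_lens = torch.cat([seq_lens, pad])

assert seq_lens_sum == int(seq_lens.sum())  # scalar and tensor stay in sync
```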
@@ -702,7 +710,7 @@ class ForwardBatch:
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
-        bs = self.batch_size
+        bs = getattr(self, "raw_bs", self.batch_size)
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
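
On the way out, `post_forward_mlp_sync_batch` switches from `self.batch_size` (now the padded size on decode) to `getattr(self, "raw_bs", self.batch_size)`, which recovers the size stashed before padding and falls back to `batch_size` whenever the decode branch never ran. A hypothetical illustration of that round-trip; the tensor names here are made up:

```python
import torch

# Made-up shapes: the forward pass ran on a batch padded from 3 to 5
# requests, so per-request outputs must be trimmed back afterwards.
padded_bs, raw_bs, vocab = 5, 3, 128
padded_logits = torch.randn(padded_bs, vocab)

bs = raw_bs  # what getattr(self, "raw_bs", self.batch_size) yields here
trimmed = padded_logits[:bs]  # discard rows that exist only for padding
assert trimmed.shape == (raw_bs, vocab)
```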