Unverified Commit 10771026 authored by IAN, committed by GitHub

[BugFix] Fix crash when receiving a request with structured output in DP attention mode. (#3841)

parent 4606e2a3
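
For context: a structured-output request is one that constrains decoding with a grammar (e.g. a JSON schema), which makes the scheduler build a vocab mask before sampling. Below is a minimal reproduction sketch of the kind of request that previously crashed the scheduler; it assumes a local SGLang server launched with --enable-dp-attention, and the endpoint and parameter names (/generate, sampling_params.json_schema) are illustrative rather than taken from this commit.

# Hypothetical reproduction sketch, not part of this commit.
# Assumes a local server started with --enable-dp-attention.
import json
import requests

payload = {
    "text": "Give me the capital of France as JSON.",
    "sampling_params": {
        "max_new_tokens": 64,
        # Grammar-constrained (structured) output: compiling this schema into
        # a grammar is what sends the request down the code path fixed here.
        "json_schema": json.dumps(
            {"type": "object", "properties": {"capital": {"type": "string"}}}
        ),
    },
}

resp = requests.post("http://127.0.0.1:30000/generate", json=payload)
print(resp.json())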
@@ -46,6 +46,7 @@ def json_decode(s, document):
 def main(args):
     lines = read_jsonl(args.data_path)
+    lines = list(lines)
     arguments = []
     for i in range(len(lines[: args.num_questions])):
         arguments.append(
......
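
The benchmark tweak above is a small robustness fix: if read_jsonl yields lines lazily (a generator), the later slice lines[: args.num_questions] raises a TypeError, so the stream is materialized first. A standalone illustration of that behavior (fake_read_jsonl is a hypothetical stand-in, not the benchmark's helper):

# Standalone illustration; fake_read_jsonl is a hypothetical stand-in.
def fake_read_jsonl():
    for i in range(5):
        yield {"id": i}

lines = fake_read_jsonl()
try:
    lines[:3]  # generators are not subscriptable
except TypeError as exc:
    print("generator:", exc)

lines = list(fake_read_jsonl())
print("list:", lines[:3])  # slicing works once the stream is materialized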
@@ -1154,6 +1154,10 @@ class Scheduler:
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_batch_result(result.bid)
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
@@ -1630,16 +1634,34 @@ class Scheduler:
             except futures._base.TimeoutError:
                 break
 
-        if self.tp_size > 1:
-            # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
-            torch.distributed.all_reduce(
-                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
-            )
-            num_ready_reqs_max = tensor.item()
-            for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_size > 1:
+                # Sync across attn TP ranks to make sure they have the same number of ready requests
+                tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+                torch.distributed.all_reduce(
+                    tensor,
+                    op=torch.distributed.ReduceOp.MAX,
+                    group=self.attn_tp_cpu_group,
+                )
+                num_ready_reqs_max = tensor.item()
+                for i in range(num_ready_reqs, num_ready_reqs_max):
+                    self.grammar_queue[i].grammar = self.grammar_queue[
+                        i
+                    ].grammar.result()
+                num_ready_reqs = num_ready_reqs_max
+        else:
+            if self.tp_size > 1:
+                # Sync across TP ranks to make sure they have the same number of ready requests
+                tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+                torch.distributed.all_reduce(
+                    tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+                )
+                num_ready_reqs_max = tensor.item()
+                for i in range(num_ready_reqs, num_ready_reqs_max):
+                    self.grammar_queue[i].grammar = self.grammar_queue[
+                        i
+                    ].grammar.result()
+                num_ready_reqs = num_ready_reqs_max
 
         self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
......
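
The second scheduler change splits the "how many grammar requests are ready" synchronization by mode: with DP attention enabled, only the attention TP sub-group (attn_tp_cpu_group) needs to agree; otherwise the full TP group is used as before. The agreement itself is a MAX all-reduce, so every rank blocks on its remaining grammar futures until it matches the fastest rank's count and all ranks admit the same requests. A self-contained sketch of that pattern on the gloo backend (the two-rank setup and per-rank counts are made up for illustration):

# Standalone sketch of the MAX all-reduce used to align ready-request counts.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank may have finished compiling a different number of grammars.
    num_ready_reqs = [1, 3][rank]

    tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    num_ready_reqs_max = tensor.item()

    # Slower ranks would block on their remaining grammar futures up to the
    # max, so every rank moves the same number of requests to waiting_queue.
    print(f"rank {rank}: ready={num_ready_reqs} -> aligned to {num_ready_reqs_max}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)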