Unverified Commit 18fd4a83 authored by Andy Lo's avatar Andy Lo Committed by GitHub
Browse files

[Bugfix] Multi-sequence broken (#11898)


Signed-off-by: default avatarAndy Lo <andy@mistral.ai>
parent 132a1321
...@@ -31,7 +31,7 @@ def test_random_sample_with_seed( ...@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
sampling_params = SamplingParams( sampling_params = SamplingParams(
# Parameters to ensure sufficient randomness # Parameters to ensure sufficient randomness
temperature=2.0, temperature=3.0,
top_p=min(random.random() + 0.3, 1), top_p=min(random.random() + 0.3, 1),
top_k=random.randint(5, 20), top_k=random.randint(5, 20),
n=random.randint(1, 10), n=random.randint(1, 10),
...@@ -75,3 +75,8 @@ def test_random_sample_with_seed( ...@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
# verify requests with the same seed match # verify requests with the same seed match
assert outputs[1] == outputs[4] assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5] assert outputs[2] == outputs[5]
# verify generations within the same parallel sampling group differ
for output in outputs:
for sub_output_a, sub_output_b in combinations(output, 2):
assert sub_output_a != sub_output_b
...@@ -172,9 +172,9 @@ class RequestOutput: ...@@ -172,9 +172,9 @@ class RequestOutput:
if seq_group.request_id in seq_id_to_seq_group: if seq_group.request_id in seq_id_to_seq_group:
group: SequenceGroupBase = seq_id_to_seq_group[ group: SequenceGroupBase = seq_id_to_seq_group[
seq_group.request_id] seq_group.request_id]
assembled_seq_group = group.maybe_assemble_group(seq_group)
if finished: if finished:
group.finish_seq(seq_group) group.finish_seq(seq_group)
assembled_seq_group = group.maybe_assemble_group(seq_group)
if assembled_seq_group is None: if assembled_seq_group is None:
return None return None
return cls.from_seq_group(assembled_seq_group, use_cache, return cls.from_seq_group(assembled_seq_group, use_cache,
......
...@@ -815,7 +815,9 @@ class SequenceGroup: ...@@ -815,7 +815,9 @@ class SequenceGroup:
def get_max_num_running_seqs(self) -> int: def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining """The maximum number of sequences running in parallel in the remaining
lifetime of the request.""" lifetime of the request."""
if self.is_single_seq:
return 0 if self.first_seq.is_finished() else 1 return 0 if self.first_seq.is_finished() else 1
return self.num_seqs() - self.num_finished_seqs()
def get_seqs( def get_seqs(
self, self,
...@@ -824,8 +826,11 @@ class SequenceGroup: ...@@ -824,8 +826,11 @@ class SequenceGroup:
if status is None: if status is None:
return self.seqs return self.seqs
if self.is_single_seq:
return self.seqs if self.first_seq.status == status else [] return self.seqs if self.first_seq.status == status else []
return [seq for seq in self.seqs if seq.status == status]
def is_encoder_decoder(self) -> bool: def is_encoder_decoder(self) -> bool:
return self.encoder_seq is not None return self.encoder_seq is not None
...@@ -833,17 +838,20 @@ class SequenceGroup: ...@@ -833,17 +838,20 @@ class SequenceGroup:
return self.encoder_seq return self.encoder_seq
def get_finished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> List[Sequence]:
if self.is_single_seq:
return self.seqs if self.first_seq.is_finished() else [] return self.seqs if self.first_seq.is_finished() else []
return [seq for seq in self.seqs if seq.is_finished()]
def update_num_computed_tokens(self, num_new_computed_tokens: int): def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far.""" """Update number of tokens computed so far."""
seq = self.first_seq for seq in self.seqs:
if not seq.is_finished(): if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens) seq.data.update_num_computed_tokens(num_new_computed_tokens)
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0 num_uncomputed_tokens = 0
seq = self.first_seq for seq in self.seqs:
if not seq.is_finished(): if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens return num_uncomputed_tokens
...@@ -860,10 +868,14 @@ class SequenceGroup: ...@@ -860,10 +868,14 @@ class SequenceGroup:
return len(self.get_seqs(status)) return len(self.get_seqs(status))
def num_finished_seqs(self) -> int: def num_finished_seqs(self) -> int:
return 1 if self.first_seq.is_finished() else 0 if self.is_single_seq:
return 1 if self.seqs[0].is_finished() else 0
return len(self.get_finished_seqs())
def is_finished(self) -> bool: def is_finished(self) -> bool:
if self.is_single_seq:
return self.first_seq.is_finished() return self.first_seq.is_finished()
return all(seq.is_finished() for seq in self.seqs)
def is_prefill(self) -> bool: def is_prefill(self) -> bool:
return self.first_seq.is_prefill() return self.first_seq.is_prefill()
...@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): ...@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
@staticmethod @staticmethod
def add_request(request_id: str, engine, params, **kwargs): def add_request(request_id: str, engine, params, **kwargs):
original_params = params original_params = params
params = original_params.clone()
params.n = 1
group = ParallelSampleSequenceGroup(request_id) group = ParallelSampleSequenceGroup(request_id)
seqs = [] seqs = []
for i in range(original_params.n): for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}" request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i group.seq_id_to_index[request_id_i] = i
params = copy.deepcopy(original_params)
params.n = 1
if params.seed is not None:
params.seed += i
seq_group = engine._add_processed_request( seq_group = engine._add_processed_request(
request_id_i, request_id_i,
params=params, params=params,
...@@ -1432,20 +1446,20 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): ...@@ -1432,20 +1446,20 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
# in the streaming mode, we will return the assembled sequence # in the streaming mode, we will return the assembled sequence
# for the first sequence, and then return None for the rest of # for the first remaining sequence, and then return None for the
# sequences # rest of sequences
if self.streaming: if self.streaming:
if self.seq_id_to_index[seq_group.request_id] == 0: first_remaining_id = next(iter(self.to_be_finished))
if seq_group.request_id == first_remaining_id:
return self.assembled_seq_group return self.assembled_seq_group
return None return None
# in the non-streaming mode, we will return the assembled sequence # in the non-streaming mode, we will return the assembled sequence
# once after all sequences finish, and then return None for the # when the last sequences finishes, and then return None for the
# rest of the time # rest of the time
if (len(self.to_be_finished) == 1
if len(self.to_be_finished) > 0: and seq_group.request_id in self.to_be_finished
return None and seq_group.is_finished()):
assert self.assembled_seq_group is not None assert self.assembled_seq_group is not None
params = self.assembled_seq_group.sampling_params params = self.assembled_seq_group.sampling_params
assert isinstance(params, SamplingParams) assert isinstance(params, SamplingParams)
...@@ -1462,3 +1476,4 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): ...@@ -1462,3 +1476,4 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
return self.assembled_seq_group return self.assembled_seq_group
if self.output_produced: if self.output_produced:
return None return None
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment