Unverified Commit 10bca45b authored by Juwan Yoo's avatar Juwan Yoo Committed by GitHub
Browse files

bugfix: penalizers to be merged before reqs (#1001)

parent b91a4cb1
......@@ -679,6 +679,11 @@ class ScheduleBatch:
setattr(self, item, self_val[new_indices])
def merge(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
......@@ -692,8 +697,6 @@ class ScheduleBatch:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
for item in [
"temperatures",
"top_ps",
......
......@@ -133,6 +133,10 @@ class BatchedPenalizerOrchestrator:
"""
Merge the penalizers of another orchestrator into this one.
Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging.
This step requires the original batch.reqs, before it gets merged with other batch.reqs.
Args:
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
"""
......
import json
import unittest
from multiprocessing import Process
import requests
......@@ -58,6 +59,40 @@ class TestBatchPenalizerE2E(unittest.TestCase):
def test_default_values(self):
self.run_decode()
def test_mixed(self):
"""
Sends two requests with one with penalizers disabled, and the other with penalizers enabled.
This will cause two different {ScheduleBatch} to be initialized and eventually gets merged.
Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not.
This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge.
This test triggers the merge of disabled + enabled.
"""
processes = []
p = Process(
target=self.run_decode,
)
processes.append(p)
p.start()
p = Process(
target=self.run_decode,
kwargs={
"frequency_penalty": 2,
"min_new_tokens": 16,
"presence_penalty": 2,
"repetition_penalty": 2,
},
)
processes.append(p)
p.start()
for p in processes:
p.join()
def test_frequency_penalty(self):
self.run_decode(frequency_penalty=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment