Unverified Commit b9fd178f authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix retraction + overlap (#1860)


Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
parent d8e9d61f
...@@ -50,7 +50,7 @@ jobs: ...@@ -50,7 +50,7 @@ jobs:
timeout-minutes: 20 timeout-minutes: 20
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end 4 python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
unit-test-backend-part-2: unit-test-backend-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...@@ -67,7 +67,7 @@ jobs: ...@@ -67,7 +67,7 @@ jobs:
timeout-minutes: 20 timeout-minutes: 20
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 4 --range-end 14 python3 run_suite.py --suite minimal --range-begin 6 --range-end 14
unit-test-backend-part-3: unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......
...@@ -211,9 +211,6 @@ class Req: ...@@ -211,9 +211,6 @@ class Req:
# this does not include the jump forward tokens. # this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0 self.completion_tokens_wo_jump_forward = 0
# The number of cached tokens, that were already cached in the KV store
self.cached_tokens = 0
# For vision inputs # For vision inputs
self.image_inputs: Optional[ImageInputs] = None self.image_inputs: Optional[ImageInputs] = None
...@@ -223,6 +220,9 @@ class Req: ...@@ -223,6 +220,9 @@ class Req:
self.last_node = None self.last_node = None
self.is_being_chunked = 0 self.is_being_chunked = 0
# For retraction
self.is_retracted = False
# Logprobs (arguments) # Logprobs (arguments)
self.return_logprob = False self.return_logprob = False
self.logprob_start_len = 0 self.logprob_start_len = 0
...@@ -242,12 +242,15 @@ class Req: ...@@ -242,12 +242,15 @@ class Req:
# The relative logprob_start_len in an extend batch # The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0 self.extend_logprob_start_len = 0
# Embedding # Embedding (return values)
self.embedding = None self.embedding = None
# Constrained decoding # Constrained decoding
self.grammar: Optional[Grammar] = None self.grammar: Optional[Grammar] = None
# The number of cached tokens, that were already cached in the KV cache
self.cached_tokens = 0
# For Qwen2-VL # For Qwen2-VL
self.mrope_position_delta = [] # use mutable object self.mrope_position_delta = [] # use mutable object
...@@ -561,7 +564,7 @@ class ScheduleBatch: ...@@ -561,7 +564,7 @@ class ScheduleBatch:
seq_lens[i] -= encoder_len seq_lens[i] -= encoder_len
if len(req.prefix_indices) < encoder_len: if len(req.prefix_indices) < encoder_len:
# NOTE: the encoder part should considered as a whole # NOTE: the encoder part should be considered as a whole
assert len(req.prefix_indices) == 0 assert len(req.prefix_indices) == 0
input_ids[i] = input_ids[i][encoder_len:] input_ids[i] = input_ids[i][encoder_len:]
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
...@@ -648,6 +651,7 @@ class ScheduleBatch: ...@@ -648,6 +651,7 @@ class ScheduleBatch:
req.extend_logprob_start_len = extend_logprob_start_len req.extend_logprob_start_len = extend_logprob_start_len
pt += req.extend_input_len pt += req.extend_input_len
req.is_retracted = False
# Set fields # Set fields
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
...@@ -780,6 +784,7 @@ class ScheduleBatch: ...@@ -780,6 +784,7 @@ class ScheduleBatch:
req.prefix_indices = [] req.prefix_indices = []
req.last_node = None req.last_node = None
req.extend_input_len = 0 req.extend_input_len = 0
req.is_retracted = True
# For incremental logprobs # For incremental logprobs
req.last_update_decode_tokens = 0 req.last_update_decode_tokens = 0
......
...@@ -79,6 +79,7 @@ from sglang.utils import get_exception_traceback ...@@ -79,6 +79,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests # Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true" crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
...@@ -831,9 +832,10 @@ class Scheduler: ...@@ -831,9 +832,10 @@ class Scheduler:
# Check finish conditions # Check finish conditions
logprob_pt = 0 logprob_pt = 0
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
if req.is_being_chunked > 0: if req.is_retracted:
req.is_being_chunked -= 1 continue
else:
if req.is_being_chunked <= 0:
# Inflight reqs' prefill is not finished # Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i]) req.output_ids.append(next_token_ids[i])
...@@ -851,12 +853,18 @@ class Scheduler: ...@@ -851,12 +853,18 @@ class Scheduler:
logprob_pt += self.add_logprob_return_values( logprob_pt += self.add_logprob_return_values(
i, req, logprob_pt, next_token_ids, logits_output i, req, logprob_pt, next_token_ids, logits_output
) )
else:
req.is_being_chunked -= 1
else: # embedding or reward model else: # embedding or reward model
embeddings, bid = result embeddings, bid = result
embeddings = embeddings.tolist() embeddings = embeddings.tolist()
# Check finish conditions # Check finish conditions
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
if req.is_retracted:
continue
req.embedding = embeddings[i] req.embedding = embeddings[i]
if req.is_being_chunked > 0: if req.is_being_chunked > 0:
req.is_being_chunked -= 1 req.is_being_chunked -= 1
...@@ -893,7 +901,12 @@ class Scheduler: ...@@ -893,7 +901,12 @@ class Scheduler:
# Check finish condition # Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if self.server_args.enable_overlap_schedule and req.finished(): if req.is_retracted:
continue
if self.server_args.enable_overlap_schedule and (
req.finished()
):
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue continue
...@@ -1015,6 +1028,7 @@ class Scheduler: ...@@ -1015,6 +1028,7 @@ class Scheduler:
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
for req in reqs: for req in reqs:
# TODO(lianmin): revisit this for overlap + retract + stream
if req.finished() or ( if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1) req.stream and (is_stream_iter or len(req.output_ids) == 1)
): ):
......
...@@ -107,6 +107,27 @@ class TestRadixCacheLPM(TestRadixCacheFCFS): ...@@ -107,6 +107,27 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
) )
class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
    """Variant of TestRadixCacheFCFS that launches the server with overlap
    scheduling enabled and the LPM schedule policy.

    Inherits all test cases from TestRadixCacheFCFS; only the server launch
    configuration differs.
    """

    @classmethod
    def setUpClass(cls):
        # Server flags: overlap schedule, small chunked-prefill size, a capped
        # total-token budget, and the LPM policy. NOTE(review): the small
        # budget presumably forces retraction under load — confirm with the
        # test suite's intent.
        launch_flags = [
            "--enable-overlap-schedule",
            "--chunked-prefill-size",
            "128",
            "--max-total-tokens",
            "20000",
            "--schedule-policy",
            "lpm",
        ]
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )
if __name__ == "__main__": if __name__ == "__main__":
os.environ["SGLANG_TEST_RETRACT"] = "true" os.environ["SGLANG_TEST_RETRACT"] = "true"
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment