Unverified commit 5dc55a5f, authored by Lianmin Zheng and committed by GitHub

Handle truncation errors (#436)

parent 4231a42f
@@ -369,7 +369,7 @@
     "\n",
     "The evaluation scores appear in the bottom right of the logs (screenshot below). Note, that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` as these evals are automatically skipped if the target answer is not provided.\n",
     "\n",
-    "![Fixed Max. Tokens](./images/rag/max-tokens-fixed-rag-trace.png)"
+    "![Fixed Max. Tokens](max-tokens-fixed-rag-trace.png)"
    ],
    "metadata": {
     "collapsed": false

@@ -16,7 +16,7 @@ class GlobalConfig:
         # Optimization configs
         self.eager_fill_image = False
-        self.enable_prefix_sharing = True
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True

@@ -86,9 +86,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint

-    # Extract prefix by tracing and cache it
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)

     # Run all programs
     if num_threads == "auto":

@@ -154,21 +154,12 @@ def run_program_batch(
     return rets


-def pin_program(program, backend):
-    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-        # TODO: handle multiple backends
-        from sglang.lang.tracer import extract_prefix_by_tracing
-
-        prefix = extract_prefix_by_tracing(program, backend)
-        if prefix and len(prefix) > 64:
-            prefix_rid = backend.cache_prefix(prefix)
-            program.pin_prefix_rid = prefix_rid
-            return prefix_rid
-    return None
-
-
-def unpin_program(program, backend):
-    pass
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing
+
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)


 class StreamExecutor:

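For context, a hedged sketch of how this pre-caching path is typically exercised: a batched run of a program whose calls share a long prompt prefix. With `enable_precache_with_tracing` left at its default of `True`, `run_program_batch` traces the program once, extracts the shared prefix, and caches it on the backend via `cache_program` before running the batch. The endpoint URL and prompt text below are illustrative assumptions, not part of this commit.

```python
import sglang as sgl

@sgl.function
def few_shot_qa(s, question):
    # The shared few-shot header below is the prefix that tracing extracts;
    # per cache_program it is only cached if it is longer than 64 characters.
    s += "Answer the questions briefly.\n"
    s += "Q: What is the capital of France?\nA: Paris\n"
    s += "Q: " + question + "\nA:"
    s += sgl.gen("answer", max_tokens=16)

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# The prefix is pre-cached once, then reused by every call in the batch.
states = few_shot_qa.run_batch(
    [{"question": "What is 2 + 2?"}, {"question": "Who wrote Hamlet?"}],
    num_threads=2,
)
for state in states:
    print(state["answer"])
```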
@@ -322,7 +313,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                print(f"Error in stream_executor: {get_exception_traceback()}")
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()

@@ -702,9 +693,10 @@ class ProgramState:
         return self.stream_executor.messages()

     def sync(self):
-        ret = self.stream_executor.sync()
-        self.error = self.stream_executor.error
-        return ret
+        return self.stream_executor.sync()
+
+    def error(self):
+        return self.stream_executor.error

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:

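A hedged usage sketch of the reworked `ProgramState` surface, reusing the `few_shot_qa` function from the sketch above: `sync()` now simply returns the executor's result, and any exception captured by the stream executor is read through the new `error()` method instead of an attribute copied onto the state during `sync()`.

```python
state = few_shot_qa.run(question="What is 2 + 2?")

state.sync()                    # block until the background executor finishes
if state.error() is not None:   # exception recorded by the stream executor, if any
    print("generation failed:", state.error())
else:
    print(state["answer"])
```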
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)

-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
-
-        backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program
+
+        backend = backend or global_config.default_backend
+        return cache_program(self, backend)

     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func

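On the front-end API, the `pin`/`unpin` pair collapses into a single `cache` call. A minimal hedged sketch, reusing the `few_shot_qa` function defined earlier and assuming a runtime endpoint on the default port:

```python
from sglang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint("http://localhost:30000")

# Traces few_shot_qa, extracts its shared prefix, and caches it on the backend.
# There is no explicit "unpin" step anymore; the cached prefix is simply reused.
few_shot_qa.cache(backend)
```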
@@ -20,6 +20,16 @@ class FinishReason(IntEnum):
     LENGTH = auto()
     STOP_STR = auto()

+    def to_str(self):
+        if self == FinishReason.EOS_TOKEN:
+            return None
+        elif self == FinishReason.LENGTH:
+            return "length"
+        elif self == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            raise ValueError(f"Invalid finish reason: {self}")


 class Req:
     def __init__(self, rid, input_text, input_ids):

@@ -612,7 +612,7 @@ class ModelRpcServer:
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
+                    "finish_reason": req.finish_reason.to_str(),
                     "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:

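With `to_str()` in place, the `finish_reason` reported in `meta_info` is a plain string: `"length"`, `"stop"`, or `None` for a natural EOS. A hedged client-side sketch, assuming a server on the default port and the usual `/generate` response shape:

```python
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 4, "temperature": 0},
    },
).json()

reason = resp["meta_info"]["finish_reason"]
if reason == "length":
    print("Stopped because max_new_tokens was reached.")
elif reason == "stop":
    print("A stop string was hit.")
else:  # None corresponds to FinishReason.EOS_TOKEN
    print("The model emitted its EOS token.")
```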
@@ -98,7 +98,6 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path, trust_remote_code=server_args.trust_remote_code
         )
-
         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):

@@ -156,6 +155,12 @@ class TokenizerManager:
         else:
             input_ids = obj.input_ids

+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)"
+            )
+
         sampling_params = SamplingParams(**obj.sampling_params)
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)

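This length guard is the core of the truncation-error handling: prompts that cannot fit in the model's context window are rejected before they are scheduled. A standalone sketch of the same logic (`check_context_len` is a hypothetical name; in the diff the check is inlined in `TokenizerManager`):

```python
def check_context_len(input_ids, context_len):
    # Hypothetical helper mirroring the inlined check above: reject prompts
    # that are at least as long as the model's context window.
    if len(input_ids) >= context_len:
        raise ValueError(
            f"The input ({len(input_ids)} tokens) is longer than the "
            f"model's context length ({context_len} tokens)"
        )

check_context_len(list(range(4096)), context_len=2048)  # raises ValueError
```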
@@ -20,7 +20,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse

 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache

@@ -90,8 +90,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    return ret
+    try:
+        ret = await tokenizer_manager.generate_request(obj).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse({"error": str(e)}, status_code=400)


 @app.post("/v1/completions")

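End to end, an over-long prompt now comes back as an HTTP 400 with a JSON error body instead of an opaque server-side failure. A hedged sketch of what a client sees, with the URL and payload shape assumed from sglang's `/generate` endpoint:

```python
import requests

long_prompt = "word " * 100_000  # almost certainly exceeds the context length

resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": long_prompt, "sampling_params": {"max_new_tokens": 8}},
)
print(resp.status_code)      # 400
print(resp.json()["error"])  # "The input (... tokens) is longer than the model's context length (...)"
```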
@@ -29,7 +29,7 @@ class TestBind(unittest.TestCase):
         tracer = few_shot_qa_2.trace()
         print(tracer.last_node.print_graph_dfs() + "\n")

-    def test_pin(self):
+    def test_cache(self):
         @sgl.function
         def few_shot_qa(s, prompt, question):
             s += prompt

@@ -41,8 +41,7 @@ class TestBind(unittest.TestCase):
         few_shot_qa_2 = few_shot_qa.bind(
             prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
         )
-        few_shot_qa_2.pin(self.backend)
-        few_shot_qa_2.unpin(self.backend)
+        few_shot_qa_2.cache(self.backend)


 if __name__ == "__main__":

...@@ -50,4 +49,4 @@ if __name__ == "__main__": ...@@ -50,4 +49,4 @@ if __name__ == "__main__":
# t = TestBind() # t = TestBind()
# t.setUp() # t.setUp()
# t.test_pin() # t.test_cache()