Unverified Commit cdc8d607 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Improve the code style: more comments and remove useless packages (#1139)

parent 9208591f
...@@ -17,7 +17,6 @@ limitations under the License. ...@@ -17,7 +17,6 @@ limitations under the License.
import asyncio import asyncio
import dataclasses import dataclasses
import inspect
from typing import List from typing import List
import uvloop import uvloop
...@@ -126,8 +125,6 @@ class DetokenizerManager: ...@@ -126,8 +125,6 @@ class DetokenizerManager:
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
) )
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
output_strs = [] output_strs = []
for i in range(bs): for i in range(bs):
s = self.decode_status[recv_obj.rids[i]] s = self.decode_status[recv_obj.rids[i]]
...@@ -144,6 +141,7 @@ class DetokenizerManager: ...@@ -144,6 +141,7 @@ class DetokenizerManager:
output_strs.append(s.decoded_text + new_text) output_strs.append(s.decoded_text + new_text)
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
pos = output_strs[i].find(recv_obj.finished_reason[i].matched) pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
if pos != -1: if pos != -1:
......
...@@ -22,8 +22,6 @@ import uuid ...@@ -22,8 +22,6 @@ import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import torch
from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams
...@@ -43,9 +41,9 @@ class GenerateReqInput: ...@@ -43,9 +41,9 @@ class GenerateReqInput:
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
# Whether to return logprobs. # Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None return_logprob: Optional[Union[List[bool], bool]] = None
# The start location of the prompt for return_logprob. # If return logprobs, the start location in the prompt for returning logprobs.
logprob_start_len: Optional[Union[List[int], int]] = None logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return. # If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: Optional[Union[List[int], int]] = None top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in text in the returned logprobs. # Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False return_text_in_logprobs: bool = False
...@@ -155,16 +153,27 @@ class GenerateReqInput: ...@@ -155,16 +153,27 @@ class GenerateReqInput:
@dataclass @dataclass
class TokenizedGenerateReqInput: class TokenizedGenerateReqInput:
# The request id
rid: str rid: str
# The input text
input_text: str input_text: str
# The input token ids
input_ids: List[int] input_ids: List[int]
# The pixel values for input images
pixel_values: List[float] pixel_values: List[float]
# The hash of input images
image_hash: int image_hash: int
# The image size
image_size: List[int] image_size: List[int]
# The sampling parameters
sampling_params: SamplingParams sampling_params: SamplingParams
# Whether to return the logprobs
return_logprob: bool return_logprob: bool
# If return logprobs, the start location in the prompt for returning logprobs.
logprob_start_len: int logprob_start_len: int
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: int top_logprobs_num: int
# Whether to stream output
stream: bool stream: bool
...@@ -215,15 +224,21 @@ class EmbeddingReqInput: ...@@ -215,15 +224,21 @@ class EmbeddingReqInput:
@dataclass @dataclass
class TokenizedEmbeddingReqInput: class TokenizedEmbeddingReqInput:
# The request id
rid: str rid: str
# The input text
input_text: str input_text: str
# The input token ids
input_ids: List[int] input_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams sampling_params: SamplingParams
@dataclass @dataclass
class BatchTokenIDOut: class BatchTokenIDOut:
# The request id
rids: List[str] rids: List[str]
# The version id to sync decode status with in detokenizer_manager
vids: List[int] vids: List[int]
decoded_texts: List[str] decoded_texts: List[str]
decode_ids: List[int] decode_ids: List[int]
...@@ -236,17 +251,25 @@ class BatchTokenIDOut: ...@@ -236,17 +251,25 @@ class BatchTokenIDOut:
@dataclass @dataclass
class BatchStrOut: class BatchStrOut:
# The request id
rids: List[str] rids: List[str]
# The output decoded strings
output_strs: List[str] output_strs: List[str]
# The meta info
meta_info: List[Dict] meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason] finished_reason: List[BaseFinishReason]
@dataclass @dataclass
class BatchEmbeddingOut: class BatchEmbeddingOut:
# The request id
rids: List[str] rids: List[str]
# The output embedding
embeddings: List[List[float]] embeddings: List[List[float]]
# The meta info
meta_info: List[Dict] meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason] finished_reason: List[BaseFinishReason]
...@@ -257,9 +280,5 @@ class FlushCacheReq: ...@@ -257,9 +280,5 @@ class FlushCacheReq:
@dataclass @dataclass
class AbortReq: class AbortReq:
# The request id
rid: str rid: str
@dataclass
class DetokenizeReqInput:
input_ids: List[int]
...@@ -34,7 +34,6 @@ from typing import Dict, List, Optional, Union ...@@ -34,7 +34,6 @@ from typing import Dict, List, Optional, Union
setattr(threading, "_register_atexit", lambda *args, **kwargs: None) setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import aiohttp import aiohttp
import psutil
import requests import requests
import uvicorn import uvicorn
import uvloop import uvloop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment