Unverified commit f424e76d authored by Liangsheng Yin, committed by GitHub

Fix illegal tokens during sampling (#676)

parent 490a1f39
@@ -7,7 +7,7 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
...
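With `text` now optional and defaulting to `None`, a `/generate` request can carry either a prompt string or pre-tokenized `input_ids`, as the field comments above note. A rough usage sketch follows, assuming an sglang server is already running on `localhost:30000` (the default port); the token ids are placeholders, not real tokenizer output.

```python
# Rough usage sketch (not part of this commit): either `text` or `input_ids`
# may be sent to /generate now that `text` defaults to None.
import requests

url = "http://localhost:30000/generate"  # assumes a locally running server

# Option 1: a plain prompt string.
resp = requests.post(url, json={
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 16},
})
print(resp.json())

# Option 2: pre-tokenized input (placeholder ids; use your model's tokenizer).
resp = requests.post(url, json={
    "input_ids": [1, 450, 7483, 310, 3444, 338],
    "sampling_params": {"max_new_tokens": 16},
})
print(resp.json())
```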
@@ -665,16 +665,20 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        try:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        except RuntimeError as e:
-            warnings.warn(f"Ignore errors in sampling: {e}")
+        max_top_k_round, batch_size = 32, probs.shape[0]
+        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, self.top_ks, self.top_ps
+        )
+
+        # FIXME: this is a temporary fix for the illegal token ids
+        illegal_mask = torch.logical_or(
+            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
+        )
+        if torch.any(illegal_mask):
+            warnings.warn("Illegal sampled token ids")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
             batch_next_token_ids = torch.argmax(probs, dim=-1)

         if has_regex:
...
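The new path replaces the blanket `try/except` around the flashinfer kernel with an explicit validity check: any sampled id that is negative or `>= vocab_size` triggers a warning and an argmax fallback over NaN-cleaned probabilities. Below is a minimal, self-contained sketch of that check, not the sglang code path itself; `torch.multinomial` stands in for `top_k_top_p_sampling_from_probs`, and an out-of-range id is injected by hand to exercise the fallback.

```python
# Standalone sketch of the fallback above (illustrative, not the sglang code path).
import warnings

import torch

vocab_size = 32000
probs = torch.softmax(torch.randn(4, vocab_size), dim=-1)

# Stand-in for flashinfer's top_k_top_p_sampling_from_probs.
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_token_ids[0] = -1  # simulate a kernel occasionally returning an illegal id

illegal_mask = torch.logical_or(next_token_ids < 0, next_token_ids >= vocab_size)
if torch.any(illegal_mask):
    warnings.warn("Illegal sampled token ids")
    # NaN probabilities would also poison argmax, so zero them out first.
    probs = probs.masked_fill(torch.isnan(probs), 0.0)
    next_token_ids = torch.argmax(probs, dim=-1)

print(next_token_ids)
```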
@@ -246,12 +246,11 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(
             self, max_batch_size_to_capture=max(batch_size_list)
         )
-        logger.info(f"Capture for batch sizes {batch_size_list}")
         try:
             self.cuda_graph_runner.capture(batch_size_list)
-        except:
+        except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed. Possible solutions:\n"
+                f"Capture cuda graph failed {e}. Possible solutions:\n"
                 f"1. disable cuda graph by --disable-cuda-graph\n"
                 f"2. set --mem-fraction-static to a smaller value\n"
                 f"Open an issue on GitHub with reproducible scripts if you need help.\n"
...
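Narrowing the bare `except:` to `except RuntimeError as e:` keeps `KeyboardInterrupt` and `SystemExit` propagating while surfacing the original error text in the advice message. A small illustrative sketch of the pattern; the failing `capture` below is a hypothetical stand-in, not `CudaGraphRunner.capture`.

```python
# Illustrative sketch only: re-raise a capture failure with the original error
# folded into an actionable message, without swallowing unrelated exceptions.
def capture(batch_size_list):  # hypothetical stand-in for CudaGraphRunner.capture
    raise RuntimeError("CUDA out of memory during graph capture")

try:
    capture([1, 2, 4, 8])
except RuntimeError as e:
    raise Exception(
        f"Capture cuda graph failed {e}. Possible solutions:\n"
        f"1. disable cuda graph by --disable-cuda-graph\n"
        f"2. set --mem-fraction-static to a smaller value\n"
    )
```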
@@ -14,7 +14,7 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
...