"docs/vscode:/vscode.git/clone" did not exist on "d5cb0be2cd16e6c5eefd4d266a38357fde83a660"
Unverified commit 40782f05, authored by Qiaolin Yu, committed by GitHub

Refactor: Move return_hidden_states to the generate input (#3985)


Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
parent 18bb216c
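In short, `return_hidden_states` moves out of `SamplingParams` and becomes a direct input to `generate`, so the sampling state no longer carries a flag that only affects what the forward pass captures. A minimal before/after sketch of the offline engine call (the model path is illustrative):

```python
import sglang as sgl

llm = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")  # illustrative model

sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 10}

# Before this commit, the flag lived inside the sampling params:
#   llm.generate(prompts, sampling_params={**sampling_params, "return_hidden_states": True})
# After this commit, it is a generate() input:
outputs = llm.generate(
    ["The capital of France is"],
    sampling_params=sampling_params,
    return_hidden_states=True,
)
llm.shutdown()
```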
@@ -57,7 +57,7 @@
    "metadata": {},
    "source": [
     "## Generate (text generation model)\n",
-    "Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](../references/sampling_params.md)."
+    "Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](https://docs.sglang.ai/backend/sampling_params.html)."
    ]
   },
   {
...
@@ -17,6 +17,7 @@ The `/generate` endpoint accepts the following parameters in JSON format. For in
 * `stream`: Whether to stream the output. `bool = False`
 * `lora_path`: Path to LoRA weights. `Optional[Union[List[Optional[str]], Optional[str]]] = None`
 * `custom_logit_processor`: Custom logit processor for advanced sampling control. For usage see below. `Optional[Union[List[Optional[str]], str]] = None`
+* `return_hidden_states`: Whether to return the hidden states of the model. Note that each time this flag changes, the CUDA graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. `bool = False`

 ## Sampling params
@@ -55,8 +56,6 @@ Please refer to our dedicated guide on [constrained decoding](https://docs.sglan
 * `ignore_eos`: Don't stop generation when EOS token is sampled. `bool = False`
 * `skip_special_tokens`: Remove special tokens during decoding. `bool = True`
 * `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below. `Optional[List[Optional[Dict[str, Any]]]] = None`
-* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the cuda graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. `bool = False`

 ### Custom Logit Processor
...
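The docs change above moves the flag into the top-level `/generate` parameters. Against a running server, an HTTP client therefore sends it next to `sampling_params`, not inside them; a minimal sketch (local URL and prompt are placeholders):

```python
import requests

# Assumes a server launched beforehand, e.g.:
#   python -m sglang.launch_server --model-path <model> --port 30000
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_hidden_states": True,  # top-level field, not a sampling param
    },
)
print(resp.json())
```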
@@ -26,10 +26,11 @@ def main():
         "temperature": 0.8,
         "top_p": 0.95,
         "max_new_tokens": 10,
-        "return_hidden_states": True,
     }
-    outputs = llm.generate(prompts, sampling_params=sampling_params)
+    outputs = llm.generate(
+        prompts, sampling_params=sampling_params, return_hidden_states=True
+    )
     for prompt, output in zip(prompts, outputs):
         print("===============================")
         print(
...
@@ -123,6 +123,7 @@ class Engine:
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
+        return_hidden_states: bool = False,
         stream: bool = False,
     ) -> Union[Dict, Iterator[Dict]]:
         """
@@ -144,6 +145,7 @@ class Engine:
             lora_path=lora_path,
             modalities=modalities_list,
             custom_logit_processor=custom_logit_processor,
+            return_hidden_states=return_hidden_states,
             stream=stream,
         )
         loop = asyncio.get_event_loop()
...
@@ -69,11 +69,15 @@ class GenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None

     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

+    # Whether to return hidden states
+    return_hidden_states: bool = False
+
     def normalize_batch_and_arguments(self):
         if (
             self.text is None and self.input_ids is None and self.input_embeds is None
@@ -218,6 +222,7 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
+            return_hidden_states=self.return_hidden_states,
         )
@@ -255,6 +260,9 @@ class TokenizedGenerateReqInput:
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[str] = None

+    # Whether to return hidden states
+    return_hidden_states: bool = False
+

 @dataclass
 class EmbeddingReqInput:
...
@@ -236,6 +236,7 @@ class Req:
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
+        return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -256,7 +257,9 @@ class Req:
         # Sampling info
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
+        self.return_hidden_states = return_hidden_states

         # Memory pool info
         self.req_pool_idx = None
@@ -608,6 +611,9 @@ class ScheduleBatch:
     # Enable custom logit processor
     enable_custom_logit_processor: bool = False

+    # Whether to return hidden states
+    return_hidden_states: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -619,6 +625,7 @@
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
@@ -633,6 +640,7 @@
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
+            return_hidden_states=return_hidden_states,
         )

     def batch_size(self):
@@ -1153,6 +1161,7 @@
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
+        self.return_hidden_states |= other.return_hidden_states

         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)
@@ -1201,7 +1210,7 @@
             spec_info=self.spec_info,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
-                if self.sampling_info.return_hidden_states
+                if self.return_hidden_states
                 else (
                     getattr(
                         self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
...
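The `capture_hidden_mode` expression above encodes a simple precedence: if any request in the batch wants hidden states, capture everything (`FULL`); otherwise defer to whatever the speculative-decoding state asks for, defaulting to `NULL`. A self-contained sketch of that decision (the enum members mirror the names in the diff; their exact set in sglang may differ):

```python
from enum import Enum, auto

class CaptureHiddenMode(Enum):
    NULL = auto()  # capture nothing
    LAST = auto()  # assumed here: capture only the last position (spec decoding)
    FULL = auto()  # capture hidden states for all positions

def choose_capture_mode(return_hidden_states: bool, spec_info) -> CaptureHiddenMode:
    # The batch-level flag wins; otherwise the spec-decoding info decides.
    if return_hidden_states:
        return CaptureHiddenMode.FULL
    return getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)

print(choose_capture_mode(True, None))   # CaptureHiddenMode.FULL
print(choose_capture_mode(False, None))  # CaptureHiddenMode.NULL
```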
@@ -631,6 +631,7 @@ class Scheduler:
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
             custom_logit_processor=custom_logit_processor,
+            return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
         )
         req.tokenizer = self.tokenizer
@@ -947,9 +948,11 @@ class Scheduler:
             if self.running_batch is not None
             else set([])
         )
+        return_hidden_states = False

         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
+            if req.return_hidden_states:
+                return_hidden_states = True
             if (
                 self.lora_paths
                 and len(
@@ -1035,6 +1038,7 @@
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            return_hidden_states,
         )
         new_batch.prepare_for_extend()
@@ -1226,7 +1230,7 @@
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
                 if (
-                    req.sampling_params.return_hidden_states
+                    req.return_hidden_states
                     and logits_output.hidden_states is not None
                 ):
                     req.hidden_states.append(
@@ -1333,10 +1337,7 @@
                         logits_output.next_token_top_logprobs_idx[i]
                     )
-            if (
-                req.sampling_params.return_hidden_states
-                and logits_output.hidden_states is not None
-            ):
+            if req.return_hidden_states and logits_output.hidden_states is not None:
                 req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

             if req.grammar is not None:
@@ -1462,10 +1463,7 @@
             completion_tokens = []
             cached_tokens = []
             spec_verify_ct = []
-            return_hidden_states = any(
-                req.sampling_params.return_hidden_states for req in reqs
-            )
-            output_hidden_states = [] if return_hidden_states else None
+            output_hidden_states = None

             if return_logprob:
                 input_token_logprobs_val = []
@@ -1532,7 +1530,9 @@
                     output_top_logprobs_val.append(req.output_top_logprobs_val)
                     output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                if req.sampling_params.return_hidden_states:
+                if req.return_hidden_states:
+                    if output_hidden_states is None:
+                        output_hidden_states = []
                     output_hidden_states.append(req.hidden_states)

             # Send to detokenizer
...
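Two details worth noting in the scheduler changes: the batch-level flag is computed while walking the waiting queue (a single opting-in request flips it for the whole batch), and `output_hidden_states` is now lazily materialized, so batches with no opted-in request never allocate the list. The same pattern in isolation (the `Req` stand-in below is illustrative, not the real class):

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Req:  # minimal stand-in for the scheduler's request object
    return_hidden_states: bool = False
    hidden_states: list = field(default_factory=list)

def collect_hidden_states(reqs: List[Req]) -> Optional[list]:
    output_hidden_states = None
    for req in reqs:
        if req.return_hidden_states:
            if output_hidden_states is None:
                output_hidden_states = []  # materialize only on first opt-in
            output_hidden_states.append(req.hidden_states)
    return output_hidden_states

print(collect_hidden_states([Req(), Req()]))                    # None
print(collect_hidden_states([Req(return_hidden_states=True)]))  # [[]]
```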
@@ -383,6 +383,7 @@ class TokenizerManager:
                 input_embeds=input_embeds,
                 session_params=session_params,
                 custom_logit_processor=obj.custom_logit_processor,
+                return_hidden_states=obj.return_hidden_states,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
...
@@ -408,13 +408,13 @@ class CudaGraphRunner:
         )

         # If the capture_hidden_mode changes, we need to recapture the graph
         if (
-            forward_batch.sampling_info.return_hidden_states
+            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
             and self.capture_hidden_mode != CaptureHiddenMode.FULL
         ):
             self.capture_hidden_mode = CaptureHiddenMode.FULL
             self.capture()
         elif (
-            not forward_batch.sampling_info.return_hidden_states
+            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
             and self.capture_hidden_mode != hidden_mode_from_spec_info
         ):
             self.capture_hidden_mode = hidden_mode_from_spec_info
...
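The runner now keys the recapture check on `forward_batch.capture_hidden_mode` rather than reaching into `sampling_info`, but the cost model is unchanged: a CUDA graph bakes in what it records, so toggling hidden-state capture invalidates the captured graph. A toy sketch of that cache-and-recapture pattern (not the real `CudaGraphRunner`):

```python
class ToyGraphRunner:
    """Caches an expensive capture keyed by the hidden-state capture mode."""

    def __init__(self) -> None:
        self.capture_hidden_mode = "NULL"
        self.capture()

    def capture(self) -> None:
        # Stands in for replaying the model under CUDA graph capture (expensive).
        print(f"(re)capturing with mode={self.capture_hidden_mode}")

    def replay(self, requested_mode: str) -> None:
        if requested_mode != self.capture_hidden_mode:
            self.capture_hidden_mode = requested_mode
            self.capture()  # mode changed -> recorded graph is stale
        # ... replay the cached graph ...

runner = ToyGraphRunner()
runner.replay("FULL")  # first request with hidden states: one recapture
runner.replay("FULL")  # same mode: cached graph is reused
runner.replay("NULL")  # flag flipped back: recapture again (the documented perf hit)
```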
@@ -37,9 +37,6 @@ class SamplingBatchInfo:
     # Whether any request has custom logit processor
     has_custom_logit_processor: bool

-    # Whether any request needs to return hidden states
-    return_hidden_states: bool
-
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
@@ -94,9 +91,6 @@
             and any(r.custom_logit_processor for r in reqs)  # then check the requests.
         )

-        # Check if any request needs to return hidden states
-        return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
-
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
             processor_dict = {}
@@ -136,7 +130,6 @@
             device=device,
             custom_params=custom_params,
             custom_logit_processor=merged_custom_logit_processor,
-            return_hidden_states=return_hidden_states,
         )

         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -344,9 +337,6 @@
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )

-        # Merge the return hidden states flag
-        self.return_hidden_states |= other.return_hidden_states
-
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
...
@@ -49,7 +49,6 @@ class SamplingParams:
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
-        return_hidden_states: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.temperature = temperature
@@ -75,7 +74,6 @@
         self.ebnf = ebnf
         self.structural_tag = structural_tag
         self.no_stop_trim = no_stop_trim
-        self.return_hidden_states = return_hidden_states
         self.custom_params = custom_params

         # Process some special cases
...
@@ -17,7 +17,6 @@ class TestHiddenState(unittest.TestCase):
         sampling_params = {
             "temperature": 0,
             "max_new_tokens": 8,
-            "return_hidden_states": True,
         }

         engine = sgl.Engine(
@@ -25,7 +24,11 @@ class TestHiddenState(unittest.TestCase):
             random_seed=42,
             skip_tokenizer_init=True,
         )
-        outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
+        outputs = engine.generate(
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_hidden_states=True,
+        )
         engine.shutdown()

         for output in outputs:
@@ -81,16 +84,9 @@ class TestHiddenState(unittest.TestCase):
         tokenizer = AutoTokenizer.from_pretrained(model_path)
         input_ids = tokenizer(prompts).input_ids

-        sample_completion = {
-            "temperature": 0,
-            "max_new_tokens": 8,
-            "return_hidden_states": True,
-        }
-        sample_hidden_state = {
+        sampling_params = {
             "temperature": 0,
             "max_new_tokens": 8,
-            "return_hidden_states": False,
         }

         engine = sgl.Engine(
@@ -99,14 +95,20 @@
             skip_tokenizer_init=True,
         )
         outputs_completion_first_round = engine.generate(
-            input_ids=input_ids, sampling_params=sample_completion
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_hidden_states=True,
         )
         outputs_hidden_state = engine.generate(
-            input_ids=input_ids, sampling_params=sample_hidden_state
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_hidden_states=False,
         )
         outputs_completion_last_round = engine.generate(
-            input_ids=input_ids, sampling_params=sample_completion
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_hidden_states=True,
         )
         engine.shutdown()
...
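For completeness, a sketch of consuming the result under the new interface. The `meta_info["hidden_states"]` key is taken from the repository's hidden_states.py example and is worth verifying against your sglang version:

```python
import sglang as sgl

llm = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")  # illustrative model
outputs = llm.generate(
    ["The capital of France is"],
    sampling_params={"temperature": 0, "max_new_tokens": 8},
    return_hidden_states=True,
)
for output in outputs:
    hidden_states = output["meta_info"]["hidden_states"]  # assumed key, see note above
    print(len(hidden_states))
llm.shutdown()
```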