Unverified Commit d6898dd2 authored by Qiaolin Yu, committed by GitHub

Add return hidden state in the native API (#3897)


Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent 71ed0183
@@ -55,6 +55,7 @@ Please refer to our dedicated guide on [constrained decoding](https://docs.sglan
 * `ignore_eos`: Don't stop generation when EOS token is sampled.
 * `skip_special_tokens`: Remove special tokens during decoding.
 * `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below.
+* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the CUDA graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information.

 ### Custom Logit Processor
...
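With this change, hidden states become a per-request sampling parameter rather than an engine-level switch. A minimal end-to-end sketch of the new API (the model path is borrowed from the example below; hidden states land in each output's `meta_info`, one entry per generated token, as the updated test asserts):

```python
import sglang as sgl

llm = sgl.Engine(model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct")
output = llm.generate(
    ["The capital of France is"],
    sampling_params={"max_new_tokens": 8, "return_hidden_states": True},
)[0]
# When requested, hidden states ride along in the output's meta_info.
print(len(output["meta_info"]["hidden_states"]))
llm.shutdown()
```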
@@ -2,7 +2,9 @@
 Usage:
 python hidden_states.py

-Note that we are actively working on moving return_hidden_states to the sampling_params.
+Note that each time you change the `return_hidden_states` parameter,
+the CUDA graph will be recaptured, which might lead to a performance hit.
+So avoid alternating between requests that return hidden states and requests that do not.
 """

 import sglang as sgl
@@ -18,10 +20,14 @@ def main():
     # Create an LLM.
     llm = sgl.Engine(
         model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
-        return_hidden_states=True,
     )

-    sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 10}
+    sampling_params = {
+        "temperature": 0.8,
+        "top_p": 0.95,
+        "max_new_tokens": 10,
+        "return_hidden_states": True,
+    }

     outputs = llm.generate(prompts, sampling_params=sampling_params)
     for prompt, output in zip(prompts, outputs):
...
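Because every flip of `return_hidden_states` forces a CUDA graph recapture, a reasonable client-side pattern is to batch requests by the flag instead of interleaving them. A sketch of that pattern, assuming an `llm` engine created as in the example above (the prompts and the `requests` list are placeholders):

```python
# Hypothetical workload: prompts tagged by whether they need hidden states.
requests = [
    ("Today is", True),
    ("The capital of France is", False),
    ("Today is a sunny day and I like", True),
]

# Issue all hidden-state requests together, then all plain completions,
# so the CUDA graph is recaptured at most twice overall.
for want_hidden in (True, False):
    prompts = [p for p, w in requests if w is want_hidden]
    outputs = llm.generate(
        prompts,
        sampling_params={
            "max_new_tokens": 10,
            "return_hidden_states": want_hidden,
        },
    )
    for prompt, output in zip(prompts, outputs):
        if want_hidden:
            print(prompt, "->", len(output["meta_info"]["hidden_states"]))
        else:
            print(prompt, "->", output["text"])
```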
@@ -607,9 +607,6 @@ class ScheduleBatch:
     # Enable custom logit processor
     enable_custom_logit_processor: bool = False

-    # Return hidden states
-    return_hidden_states: bool = False
-
     @classmethod
     def init_new(
         cls,
@@ -621,7 +618,6 @@ class ScheduleBatch:
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
-        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
@@ -636,7 +632,6 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
-            return_hidden_states=return_hidden_states,
         )

     def batch_size(self):
@@ -1205,7 +1200,7 @@ class ScheduleBatch:
             spec_info=self.spec_info,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
-                if self.return_hidden_states
+                if self.sampling_info.return_hidden_states
                 else (
                     getattr(
                         self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
...
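The nested conditional in the last hunk above encodes a simple precedence for the capture mode. A standalone restatement of that rule (a sketch: the enum is a trimmed stand-in with only the members visible in this diff, and `pick_capture_mode` is a name invented here):

```python
from enum import Enum, auto


class CaptureHiddenMode(Enum):
    # Trimmed stand-in for the real enum; only the members visible in this diff.
    NULL = auto()
    FULL = auto()


def pick_capture_mode(return_hidden_states: bool, spec_info) -> CaptureHiddenMode:
    # A hidden-state request anywhere in the batch wins and forces FULL;
    # otherwise defer to speculative decoding, defaulting to NULL.
    if return_hidden_states:
        return CaptureHiddenMode.FULL
    return getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)


assert pick_capture_mode(True, None) is CaptureHiddenMode.FULL
assert pick_capture_mode(False, None) is CaptureHiddenMode.NULL
```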
@@ -1030,7 +1030,6 @@ class Scheduler:
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
-            self.server_args.return_hidden_states,
         )

         new_batch.prepare_for_extend()
@@ -1221,9 +1220,8 @@
                 logprob_pt += self.add_logprob_return_values(
                     i, req, logprob_pt, next_token_ids, logits_output
                 )
-
             if (
-                self.server_args.return_hidden_states
+                req.sampling_params.return_hidden_states
                 and logits_output.hidden_states is not None
             ):
                 req.hidden_states.append(
@@ -1331,7 +1329,7 @@
             )

             if (
-                self.server_args.return_hidden_states
+                req.sampling_params.return_hidden_states
                 and logits_output.hidden_states is not None
             ):
                 req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
@@ -1459,7 +1457,10 @@
         completion_tokens = []
         cached_tokens = []
         spec_verify_ct = []
-        output_hidden_states = [] if self.server_args.return_hidden_states else None
+        return_hidden_states = any(
+            req.sampling_params.return_hidden_states for req in reqs
+        )
+        output_hidden_states = [] if return_hidden_states else None

         if return_logprob:
             input_token_logprobs_val = []
@@ -1526,7 +1527,7 @@
                     output_top_logprobs_val.append(req.output_top_logprobs_val)
                     output_top_logprobs_idx.append(req.output_top_logprobs_idx)

-                if self.server_args.return_hidden_states:
+                if req.sampling_params.return_hidden_states:
                     output_hidden_states.append(req.hidden_states)

         # Send to detokenizer
@@ -1619,7 +1620,6 @@
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
-            self.server_args.return_hidden_states,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
...
@@ -120,7 +120,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-        # is very samll. We add more values here to make sure we capture the maximum bs.
+        # is very small. We add more values here to make sure we capture the maximum bs.
         capture_bs = list(
             sorted(
                 set(
@@ -175,6 +175,7 @@ class CudaGraphRunner:
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
+        self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
@@ -335,6 +336,10 @@
             gathered_buffer = None

         spec_info = self.get_spec_info(num_tokens)
+        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
+            self.capture_hidden_mode = (
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            )

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -355,15 +360,7 @@
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
-            capture_hidden_mode=(
-                CaptureHiddenMode.FULL
-                if self.model_runner.server_args.return_hidden_states
-                else (
-                    spec_info.capture_hidden_mode
-                    if spec_info
-                    else CaptureHiddenMode.NULL
-                )
-            ),
+            capture_hidden_mode=self.capture_hidden_mode,
         )

         # Attention backend
@@ -406,6 +403,23 @@
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
+        hidden_mode_from_spec_info = getattr(
+            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+        )
+        # If the capture_hidden_mode changes, we need to recapture the graph
+        if (
+            forward_batch.sampling_info.return_hidden_states
+            and self.capture_hidden_mode != CaptureHiddenMode.FULL
+        ):
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+            self.capture()
+        elif (
+            not forward_batch.sampling_info.return_hidden_states
+            and self.capture_hidden_mode != hidden_mode_from_spec_info
+        ):
+            self.capture_hidden_mode = hidden_mode_from_spec_info
+            self.capture()
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
...
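The `replay` check above reduces to: compute the mode the incoming batch needs, and recapture only when the graph's current mode disagrees. A toy restatement of that decision (a sketch reusing the trimmed `CaptureHiddenMode` from the earlier snippet; `ToyRunner` and both function names are invented here):

```python
from dataclasses import dataclass


@dataclass
class ToyRunner:
    # Stand-in for CudaGraphRunner, with just enough state to demo the logic.
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
    recaptures: int = 0

    def capture(self) -> None:
        # The real method re-records all CUDA graphs; that is the costly step.
        self.recaptures += 1


def desired_mode(wants_hidden: bool, spec_mode: CaptureHiddenMode) -> CaptureHiddenMode:
    # FULL whenever the batch asked for hidden states, else whatever the
    # speculative-decoding path needs (NULL when there is none).
    return CaptureHiddenMode.FULL if wants_hidden else spec_mode


def maybe_recapture(runner: ToyRunner, wants_hidden: bool, spec_mode: CaptureHiddenMode) -> None:
    target = desired_mode(wants_hidden, spec_mode)
    if runner.capture_hidden_mode != target:
        runner.capture_hidden_mode = target
        runner.capture()


runner = ToyRunner()
maybe_recapture(runner, True, CaptureHiddenMode.NULL)   # NULL -> FULL: recapture
maybe_recapture(runner, True, CaptureHiddenMode.NULL)   # already FULL: no-op
maybe_recapture(runner, False, CaptureHiddenMode.NULL)  # FULL -> NULL: recapture
assert runner.recaptures == 2
```

This is why the docs and the example warn against alternating: each flip of the flag pays for a full recapture.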
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
     # Whether any request has custom logit processor
     has_custom_logit_processor: bool

+    # Whether any request needs to return hidden states
+    return_hidden_states: bool
+
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
@@ -91,6 +94,9 @@
             and any(r.custom_logit_processor for r in reqs)  # then check the requests.
         )

+        # Check if any request needs to return hidden states
+        return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
+
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
             processor_dict = {}
@@ -130,6 +136,7 @@
             device=device,
             custom_params=custom_params,
             custom_logit_processor=merged_custom_logit_processor,
+            return_hidden_states=return_hidden_states,
         )

         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -336,6 +343,10 @@
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )

+        # Merge the return hidden states flag
+        self.return_hidden_states |= other.return_hidden_states
+
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
...
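When two running batches merge, the flag is ORed so it stays sticky: the merged batch keeps capturing hidden states if either side asked for them. A toy illustration (a sketch; `MiniBatchInfo` is an invented stand-in for the flag-related slice of `SamplingBatchInfo`):

```python
from dataclasses import dataclass


@dataclass
class MiniBatchInfo:
    return_hidden_states: bool

    def merge_batch(self, other: "MiniBatchInfo") -> None:
        # Mirrors the diff: |= keeps the flag set once either side requested it.
        self.return_hidden_states |= other.return_hidden_states


a = MiniBatchInfo(return_hidden_states=False)
b = MiniBatchInfo(return_hidden_states=True)
a.merge_batch(b)
assert a.return_hidden_states  # one requesting side is enough
```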
@@ -48,6 +48,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        return_hidden_states: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.temperature = temperature
@@ -72,6 +73,7 @@ class SamplingParams:
         self.json_schema = json_schema
         self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim
+        self.return_hidden_states = return_hidden_states
         self.custom_params = custom_params

         # Process some special cases
...
@@ -162,7 +162,6 @@ class ServerArgs:
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
     allow_auto_truncate: bool = False
-    return_hidden_states: bool = False
     enable_custom_logit_processor: bool = False
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
@@ -917,11 +916,6 @@ class ServerArgs:
             action="store_true",
             help="Enable users to pass custom logit processors to the server (disabled by default for security)",
         )
-        parser.add_argument(
-            "--return-hidden-states",
-            action="store_true",
-            help="Return hidden states in the response.",
-        )
         parser.add_argument(
             "--tool-call-parser",
             type=str,
...
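With `--return-hidden-states` gone from the server arguments, existing callers move the flag from engine construction into per-request sampling params, exactly as the test update below does. In short (a sketch; `model_path` and `prompts` are assumed defined as in the earlier examples):

```python
# Before this commit: an engine-wide switch.
#   llm = sgl.Engine(model_path=model_path, return_hidden_states=True)

# After this commit: a per-request sampling parameter.
llm = sgl.Engine(model_path=model_path)
outputs = llm.generate(
    prompts,
    sampling_params={"max_new_tokens": 8, "return_hidden_states": True},
)
```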
@@ -14,12 +14,15 @@ class TestHiddenState(unittest.TestCase):
         tokenizer = AutoTokenizer.from_pretrained(model_path)
         input_ids = tokenizer(prompts).input_ids

-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+        sampling_params = {
+            "temperature": 0,
+            "max_new_tokens": 8,
+            "return_hidden_states": True,
+        }

         engine = sgl.Engine(
             model_path=model_path,
             random_seed=42,
-            return_hidden_states=True,
             skip_tokenizer_init=True,
         )
         outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
@@ -72,6 +75,58 @@ class TestHiddenState(unittest.TestCase):
             )
         )

+    def test_repeatedly_changes_hidden_states(self):
+        prompts = ["Today is", "Today is a sunny day and I like"]
+        model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        input_ids = tokenizer(prompts).input_ids
+
+        sample_completion = {
+            "temperature": 0,
+            "max_new_tokens": 8,
+            "return_hidden_states": True,
+        }
+        sample_hidden_state = {
+            "temperature": 0,
+            "max_new_tokens": 8,
+            "return_hidden_states": False,
+        }
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            skip_tokenizer_init=True,
+        )
+        outputs_completion_first_round = engine.generate(
+            input_ids=input_ids, sampling_params=sample_completion
+        )
+        outputs_hidden_state = engine.generate(
+            input_ids=input_ids, sampling_params=sample_hidden_state
+        )
+        outputs_completion_last_round = engine.generate(
+            input_ids=input_ids, sampling_params=sample_completion
+        )
+        engine.shutdown()
+
+        for (
+            output_completion_first_round,
+            output_hidden_state,
+            output_completion_last_round,
+        ) in zip(
+            outputs_completion_first_round,
+            outputs_hidden_state,
+            outputs_completion_last_round,
+        ):
+            self.assertEqual(
+                len(output_completion_first_round["meta_info"]["hidden_states"]), 8
+            )
+            self.assertNotIn("hidden_states", output_hidden_state["meta_info"])
+            self.assertEqual(
+                len(output_completion_last_round["meta_info"]["hidden_states"]), 8
+            )

 if __name__ == "__main__":
     unittest.main()