"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9d945b2b90fb10f3d4299f1402ca8bf78fe7f7b8"
Unverified commit 85d2365d, authored by Qiaolin Yu and committed by GitHub

Fix the output of hidden states after HTTP requests (#4269)

parent 5fe79605
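
Why the fix is needed: torch.Tensor objects cannot pass through json.dumps, so hidden states returned by the HTTP server must travel as nested Python lists and be rebuilt into tensors on the client. A minimal standalone sketch of that round-trip (illustrative shapes and names, not code from this commit):

import json

import torch

# Stand-in for one hidden-states entry; the 2 x 4 shape is an assumption.
h = torch.randn(2, 4)

try:
    json.dumps({"hidden_states": h})
except TypeError as err:
    print(err)  # Object of type Tensor is not JSON serializable

# Server side: plain nested lists survive JSON encoding.
payload = json.dumps({"hidden_states": h.tolist()})

# Client side: decode and rebuild the tensor, as the updated examples below do.
rebuilt = torch.tensor(json.loads(payload)["hidden_states"], dtype=torch.bfloat16)
print(rebuilt.shape)  # torch.Size([2, 4])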
@@ -7,6 +7,8 @@ the cuda graph will be recaptured, which might lead to a performance hit.
 So avoid getting hidden states and completions alternately.
 """
+import torch
+
 import sglang as sgl
@@ -31,11 +33,29 @@ def main():
     outputs = llm.generate(
         prompts, sampling_params=sampling_params, return_hidden_states=True
     )
+    llm.shutdown()
+
     for prompt, output in zip(prompts, outputs):
+        for i in range(len(output["meta_info"]["hidden_states"])):
+            output["meta_info"]["hidden_states"][i] = torch.tensor(
+                output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
+            )
         print("===============================")
         print(
-            f"Prompt: {prompt}\nGenerated text: {output['text']}\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\tCompletion_tokens: {output['meta_info']['completion_tokens']}\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}"
+            f"Prompt: {prompt}\n"
+            f"Generated text: {output['text']}\n"
+            f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
+            f"Completion_tokens: {output['meta_info']['completion_tokens']}"
         )
+        print("Hidden states: ")
+        hidden_states = torch.cat(
+            [
+                i.unsqueeze(0) if len(i.shape) == 1 else i
+                for i in output["meta_info"]["hidden_states"]
+            ]
+        )
+        print(hidden_states)
         print()
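
A note on the torch.cat reconstruction in the hunk above: the entries of meta_info["hidden_states"] are not all the same rank. The first entry spans the prompt tokens (2-D, one row per token), while each decode step appends a single 1-D vector, which is why 1-D entries are unsqueezed before concatenation; the batch-splicing shape check in the test at the end of this diff points at the same layout. A standalone sketch with illustrative shapes (a hidden size of 4 is an assumption):

import torch

prefill = torch.randn(3, 4)                        # 3 prompt tokens, one row each
decode_steps = [torch.randn(4) for _ in range(2)]  # one 1-D vector per new token

entries = [prefill] + decode_steps
hidden_states = torch.cat(
    [e.unsqueeze(0) if len(e.shape) == 1 else e for e in entries]
)
print(hidden_states.shape)  # torch.Size([5, 4]): 3 prompt + 2 decode tokens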
@@ -9,6 +9,7 @@ So avoid getting hidden states and completions alternately.
 """
 import requests
+import torch

 from sglang.test.test_utils import is_in_ci
 from sglang.utils import print_highlight, terminate_process, wait_for_server
@@ -50,20 +51,31 @@ def main():
         json=json_data,
     )
+    terminate_process(server_process)
+
     outputs = response.json()
     for prompt, output in zip(prompts, outputs):
+        for i in range(len(output["meta_info"]["hidden_states"])):
+            output["meta_info"]["hidden_states"][i] = torch.tensor(
+                output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
+            )
         print("===============================")
         print(
             f"Prompt: {prompt}\n"
             f"Generated text: {output['text']}\n"
             f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
-            f"Completion_tokens: {output['meta_info']['completion_tokens']}\n"
-            f"Hidden states: {output['meta_info']['hidden_states']}"
+            f"Completion_tokens: {output['meta_info']['completion_tokens']}"
         )
+        print("Hidden states: ")
+        hidden_states = torch.cat(
+            [
+                i.unsqueeze(0) if len(i.shape) == 1 else i
+                for i in output["meta_info"]["hidden_states"]
+            ]
+        )
+        print(hidden_states)
         print()
-    terminate_process(server_process)


 if __name__ == "__main__":
     main()
@@ -361,7 +361,7 @@ class Req:
         ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
             self.output_token_ids_logprobs_idx
         ) = None
-        self.hidden_states = []
+        self.hidden_states: List[List[float]] = []

         # Embedding (return values)
         self.embedding = None
@@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin:
                         ]
                         .cpu()
                         .clone()
+                        .tolist()
                     )

             if req.grammar is not None:
@@ -245,7 +246,9 @@ class SchedulerOutputProcessorMixin:
                 )

             if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+                req.hidden_states.append(
+                    logits_output.hidden_states[i].cpu().clone().tolist()
+                )

             if req.grammar is not None and batch.spec_algorithm.is_none():
                 req.grammar.accept_token(next_token_id)
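
The .cpu().clone().tolist() chain above is the server-side half of the fix: .cpu() copies the row off the device, .clone() avoids keeping a view into a buffer the scheduler may reuse, and .tolist() turns it into plain floats that the JSON layer can encode. A minimal sketch of the pattern (step_hidden is a hypothetical stand-in for logits_output.hidden_states[i]):

from typing import List

import torch

hidden_states: List[List[float]] = []  # mirrors the new Req annotation above

step_hidden = torch.randn(8)  # hypothetical per-token hidden vector
hidden_states.append(step_hidden.cpu().clone().tolist())

assert isinstance(hidden_states[-1], list)  # JSON-safe, unlike a Tensor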
@@ -33,8 +33,11 @@ class TestHiddenState(unittest.TestCase):
         for output in outputs:
             self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
-            for hidden_state in output["meta_info"]["hidden_states"]:
-                self.assertIsInstance(hidden_state, torch.Tensor)
+            for i in range(len(output["meta_info"]["hidden_states"])):
+                assert isinstance(output["meta_info"]["hidden_states"][i], list)
+                output["meta_info"]["hidden_states"][i] = torch.tensor(
+                    output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
+                )

         # Checks that splicing of the batch was done correctly
         self.assertGreater(
             outputs[1]["meta_info"]["hidden_states"][0].shape[0],