"docs/vscode:/vscode.git/clone" did not exist on "84fccf579324661107c4bf407b94ace92a278339"
Unverified commit 32f61443 authored by Ying Sheng, committed by GitHub

fix: Fix returned prefill logits and add output str test (#1046)

parent fb1f28cb
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
             all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
+            if hasattr(self.config, "final_logit_softcapping"):
+                all_logits /= self.config.final_logit_softcapping
+                all_logits = torch.tanh(all_logits)
+                all_logits *= self.config.final_logit_softcapping
+
             all_logprobs = all_logits
             del all_logits, hidden_states
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
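Note: the added block applies Gemma-2 style final-logit soft-capping, which bounds every logit to the range (-cap, +cap) by passing it through a tanh. A minimal standalone sketch of the same transform, assuming an illustrative cap of 30.0 (the real value comes from the model config's final_logit_softcapping attribute):

import torch

def softcap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Same sequence as the diff above, just not in-place: divide, tanh, multiply back.
    return torch.tanh(logits / cap) * cap

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(softcap(x))  # values far outside +/-30 are squashed toward +/-30; small values pass through nearly unchanged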
@@ -26,9 +26,11 @@ from sglang.srt.server import Runtime
 from sglang.srt.utils import is_generation_model
 
 DEFAULT_PROMPTS = [
-    "The capital of France is",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
+    "AI is a field of computer science focused on",
 ]
 
 NUM_TOP_LOGPROBS = 5
@@ -43,10 +45,11 @@ def get_dtype_str(torch_dtype):
 @dataclass
 class ModelOutput:
-    output_strs: str = None
-    top_input_logprobs: torch.Tensor = None
-    top_output_logprobs: torch.Tensor = None
-    embed_logits: torch.Tensor = None
+    output_strs: List[str] = None
+    output_ids: List[int] = None
+    top_input_logprobs: List[torch.Tensor] = None
+    top_output_logprobs: List[torch.Tensor] = None
+    embed_logits: List[torch.Tensor] = None
 
 
 class HFRunner:
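The ModelOutput fields switch from single values to per-prompt lists, so index i of every field refers to prompt i and can be compared across runners. A hedged usage sketch with a local stand-in for the dataclass (all values illustrative):

import torch
from dataclasses import dataclass
from typing import List

@dataclass
class ModelOutput:  # local stand-in mirroring the fields in the diff above
    output_strs: List[str] = None
    output_ids: List[int] = None
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    embed_logits: List[torch.Tensor] = None

out = ModelOutput(
    output_strs=[" Paris.", " London."],                        # one decoded continuation per prompt
    top_input_logprobs=[torch.zeros(4, 5), torch.zeros(6, 5)],  # one (prompt_len, k) tensor per prompt
)
print(out.output_strs[0], out.top_input_logprobs[0].shape)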
@@ -117,7 +120,9 @@ class HFRunner:
                 output_ids = self.model.generate(
                     input_ids, do_sample=False, max_new_tokens=max_new_tokens
                 )
-                output_strs.append(self.tokenizer.decode(output_ids[0]))
+                output_strs.append(
+                    self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                )
 
                 logits = self.model.forward(input_ids).logits[0]
                 logprobs = F.log_softmax(
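This decode change matters because Hugging Face generate() returns the prompt token ids followed by the newly generated ones, so slicing off the first len(input_ids[0]) ids leaves only the continuation. A tiny illustration with made-up token ids:

prompt_ids = [101, 2023, 2003]           # token ids of the prompt
output_ids = prompt_ids + [1037, 2742]   # generate() echoes the prompt, then appends new tokens
new_ids = output_ids[len(prompt_ids):]   # -> [1037, 2742], the continuation alone
# Decoding new_ids yields only the completion text, which is what should be compared across runners.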
@@ -145,7 +150,7 @@ class HFRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         self.in_queue.put((prompts, max_new_tokens))
         return self.out_queue.get()
@@ -184,7 +189,7 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         if self.is_generation_model:
             # the return value contains logprobs from prefill
@@ -21,23 +21,25 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 MODELS = [
     ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
+    ("google/gemma-2-2b", 1),
 ]
 TORCH_DTYPES = [torch.float16]
 
 
-class TestCausalModels(unittest.TestCase):
+class TestGenerationModels(unittest.TestCase):
-    def assert_close_prefill_logits(
+    def assert_close_prefill_logits_and_output_strs(
         self,
         prompts,
         model_path,
         tp_size,
         torch_dtype,
+        max_new_tokens,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=True
         ) as hf_runner:
-            hf_outputs = hf_runner.forward(prompts)
+            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         with SRTRunner(
             model_path,
@@ -45,7 +47,7 @@ class TestCausalModels(unittest.TestCase):
             torch_dtype=torch_dtype,
             is_generation_model=True,
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(prompts)
+            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         for i in range(len(prompts)):
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
@@ -56,11 +58,18 @@ class TestCausalModels(unittest.TestCase):
                     abs(hf_logprobs - srt_logprobs) < tolerance
                 ), f"prefill logprobs not all close"
 
+        assert hf_outputs.output_strs == srt_outputs.output_strs
+
     def test_prefill_logits(self):
         for model, tp_size in MODELS:
             for torch_dtype in TORCH_DTYPES:
-                self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
+                max_new_tokens = 8
+                self.assert_close_prefill_logits_and_output_strs(
+                    DEFAULT_PROMPTS,
+                    model,
+                    tp_size,
+                    torch_dtype,
+                    max_new_tokens,
                 )
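Taken together, the updated test compares the top-k prefill logprobs element-wise within a tolerance and then requires the greedily decoded strings to match exactly. A minimal sketch of that check, assuming two ModelOutput-like results and an illustrative tolerance (the real value is defined elsewhere in the test):

import torch

def check(hf_out, srt_out, prompts, tolerance=1e-2):
    for i in range(len(prompts)):
        hf_lp = torch.Tensor(hf_out.top_input_logprobs[i])
        srt_lp = torch.Tensor(srt_out.top_input_logprobs[i])
        # Prefill logprobs must agree element-wise within the tolerance.
        assert torch.all(abs(hf_lp - srt_lp) < tolerance), "prefill logprobs not all close"
    # Greedy-decoded continuations must match exactly.
    assert hf_out.output_strs == srt_out.output_strs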