Unverified Commit 689ff588, authored by Ying Sheng, committed by GitHub

[CI] Return output logprobs in unit test (#1361)

parent a7c47e0f
@@ -50,6 +50,12 @@ def get_dtype_str(torch_dtype):
     raise NotImplementedError()
 
 
+def get_top_logprobs(logits, k):
+    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
+    return logprobs
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -108,7 +114,8 @@ class HFRunner:
             if prompts is not None:
                 if self.is_generation:
                     output_strs = []
-                    prefill_logprobs = []
+                    top_input_logprobs = []
+                    top_output_logprobs = []
                     for p in prompts:
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
@@ -117,32 +124,43 @@ class HFRunner:
                         else:
                             input_ids = torch.tensor([p], device="cuda")
 
-                        output_ids = self.model.generate(
-                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        outputs = self.model.generate(
+                            input_ids,
+                            do_sample=False,
+                            temperature=None,
+                            top_p=None,
+                            max_new_tokens=max_new_tokens,
+                            return_dict_in_generate=True,
+                            output_scores=True,
                         )
                         output_strs.append(
-                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                            self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
+                        # outputs.scores: (num_token, 1, vocab_size)
+                        top_output_logprobs.append(
+                            [
+                                get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
+                                for logits in outputs.scores
+                            ]
+                        )
+                        del outputs
 
-                        logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
-                        logprobs, top_indices = torch.topk(
-                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                        input_logits = self.model.forward(input_ids).logits[0]
+                        top_input_logprobs.append(
+                            get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                         )
-                        # print("index", top_indices)
-                        prefill_logprobs.append(logprobs.tolist())
-                        del logits
-                        del logprobs
+                        del input_logits
 
                     out_queue.put(
                         ModelOutput(
-                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                            output_strs=output_strs,
+                            top_input_logprobs=top_input_logprobs,
+                            top_output_logprobs=top_output_logprobs,
                         )
                     )
                 else:
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
     def forward(
@@ -194,6 +212,7 @@ class SRTRunner:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
+            top_output_logprobs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
             for prompt in prompts:
                 response = self.runtime.generate(
@@ -219,9 +238,17 @@ class SRTRunner:
                         ]
                     ]
                 )
+                top_output_logprobs.append(
+                    [
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["output_top_logprobs"]
+                    ]
+                )
+
             return ModelOutput(
-                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+                output_strs=output_strs,
+                top_input_logprobs=top_input_logprobs,
+                top_output_logprobs=top_output_logprobs,
             )
         else:
             response = self.runtime.encode(prompts)
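For context, the HF runner change above follows a standard Hugging Face pattern: `generate` is called with `return_dict_in_generate=True` and `output_scores=True`, so `outputs.scores` holds one logits tensor per generated token, and the new `get_top_logprobs` helper reduces each step (and the prefill logits from a plain forward pass) to its top-k logprobs. The following is only a minimal standalone sketch of that pattern; the model name ("gpt2"), the prompt, and k=5 are illustrative placeholders, not values from this commit.

# Minimal sketch (not the repository's test code): get per-step scores from
# HF generate() and reduce them to top-k logprobs, as the diff above does
# with NUM_TOP_LOGPROBS. Model name, prompt, and k=5 are placeholders.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_top_logprobs(logits, k):
    # Same idea as the helper added in this commit: log-softmax in float32,
    # then keep only the k largest logprobs per position.
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    logprobs, _top_indices = torch.topk(logprobs, k=k, dim=-1)
    return logprobs


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello, my name is", return_tensors="pt")
outputs = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=8,
    return_dict_in_generate=True,
    output_scores=True,
)

# outputs.scores is a tuple with one (batch_size, vocab_size) tensor per new token.
top_output_logprobs = [get_top_logprobs(step[0], 5).tolist() for step in outputs.scores]

# Prefill (input) logprobs come from a plain forward pass over the prompt.
input_logits = model(input_ids).logits[0]  # (seq_len, vocab_size)
top_input_logprobs = get_top_logprobs(input_logits, 5).tolist()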
@@ -21,9 +21,9 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 
 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
-    ("google/gemma-2-2b", 1, 3, 3e-2, 1),
-    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
+    ("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
 ]
 
 TORCH_DTYPES = [torch.float16]
@@ -70,6 +70,7 @@ class TestGenerationModels(unittest.TestCase):
         torch_dtype,
         max_new_tokens,
         prefill_tolerance,
+        output_tolerance,
         rouge_threshold,
         long_context_tolerance,
     ) -> None:
@@ -89,15 +90,37 @@ class TestGenerationModels(unittest.TestCase):
             srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         for i in range(len(prompts)):
+            # input logprobs comparison
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
             srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
-            print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
-            if hf_logprobs.shape[0] <= 100:
+            input_len = hf_logprobs.shape[0]
+            print(
+                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
                 assert torch.all(
                     abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                 ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
 
+            # output logprobs comparison
+            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
+            # print(
+            #     "output logprobs diff",
+            #     [
+            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
+            #         for j in range(max_new_tokens)
+            #     ],
+            # )
+            print(
+                "output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(
+                    abs(hf_logprobs - srt_logprobs) < output_tolerance
+                ), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"
+
+            # output strings comparison
             print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
             print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
             rouge_l_scores = calculate_rouge_l(
@@ -114,6 +137,7 @@ class TestGenerationModels(unittest.TestCase):
             tp_size,
             long_context_tolerance,
             prefill_tolerance,
+            output_tolerance,
             rouge_threshold,
         ) in MODELS:
             for torch_dtype in TORCH_DTYPES:
@@ -125,6 +149,7 @@ class TestGenerationModels(unittest.TestCase):
                     torch_dtype,
                     max_new_tokens,
                     prefill_tolerance=prefill_tolerance,
+                    output_tolerance=output_tolerance,
                     rouge_threshold=rouge_threshold,
                     long_context_tolerance=long_context_tolerance,
                 )
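The test changes above apply the same elementwise-tolerance check to decode-time logprobs that was already applied to prefill logprobs, driven by the new per-model `output_tolerance` column in `MODELS`. The snippet below is a self-contained sketch of that comparison only; the tensors are made-up stand-ins for `hf_outputs.top_output_logprobs[i]` and `srt_outputs.top_output_logprobs[i]`, and the 4e-2 tolerance is just one of the values listed above.

# Sketch of the comparison pattern, with made-up logprob tensors standing in
# for the HF and SRT runner outputs; the tolerance value is illustrative.
import torch

output_tolerance = 4e-2

hf_logprobs = torch.tensor([[-0.10, -2.31, -3.05], [-0.08, -2.90, -3.44]])
srt_logprobs = torch.tensor([[-0.11, -2.33, -3.02], [-0.07, -2.88, -3.47]])

max_diff = torch.max(torch.abs(hf_logprobs - srt_logprobs))
print("output logprobs max_diff", max_diff)

# Elementwise check: every top-k logprob of every generated token must agree
# within the per-model tolerance, mirroring the assert added in this commit.
assert torch.all(
    torch.abs(hf_logprobs - srt_logprobs) < output_tolerance
), f"output logprobs are not all close, max_diff={max_diff}"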