"vscode:/vscode.git/clone" did not exist on "abb89da4de3b9196933d9a885db822d60da18cac"
Unverified commit aba9eae4, authored by Lianmin Zheng, committed by GitHub

Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)

parent bbd72bfc
bench_latency.py

@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


+@torch.inference_mode()
 def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch


+@torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
     model_worker_batch = batch.get_model_worker_batch()
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-@torch.inference_mode()
 def correctness_test(
     server_args,
     port_args,
@@ -287,7 +288,6 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


-@torch.inference_mode()
 def latency_test_run_once(
     run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
 ):
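For context: the hunks above move @torch.inference_mode() off the top-level drivers (correctness_test, latency_test_run_once) and onto the per-batch extend and decode steps, so no-grad inference mode is entered around each forward pass itself rather than around the code that launches the tp > 1 worker processes. A minimal sketch of the decorator form, using a toy model rather than the SGLang model runner:

import torch


@torch.inference_mode()
def forward_step(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Autograd is disabled inside the decorated call; outputs are
    # inference tensors with no grad bookkeeping.
    return model(x)


model = torch.nn.Linear(8, 2)
out = forward_step(model, torch.randn(4, 8))
print(out.requires_grad)  # False

Scoping the decorator to the innermost step functions keeps inference mode tied to the forward passes, independent of how many processes end up calling them.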
test_generation_models.py

@@ -42,13 +42,13 @@ class ModelCase:
     rouge_l_tolerance: float = 1


-# Popular models that run on CI
+# Popular models that run on the CI
 CI_MODELS = [
     ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
     ModelCase("google/gemma-2-2b"),
 ]

-# All other models
+# All other models that do not run on the CI
 ALL_OTHER_MODELS = [
     ModelCase("Qwen/Qwen2-1.5B"),
     ModelCase("Qwen/Qwen2.5-14B-Instruct"),
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]


 class TestGenerationModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn")
+
     def assert_close_logits_and_output_strs(
         self,
         prompts: List[str],
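Moving mp.set_start_method("spawn") from the __main__ guard into setUpClass means the start method is also set when a test runner such as pytest imports the module and never executes the __main__ block. A minimal sketch of the pattern; the worker function is hypothetical and stands in for the per-model subprocess the real test launches:

import multiprocessing as mp
import unittest


def _worker(q):
    # With "spawn", the child starts from a fresh interpreter instead of
    # inheriting the parent's state (necessary once CUDA is initialized,
    # since a forked child cannot reuse the parent's CUDA context).
    q.put("ok")


class SpawnedSubprocessTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once before any test in the class, no matter how the file
        # is launched. set_start_method raises if called twice, so it
        # needs a single well-defined call site like this one.
        mp.set_start_method("spawn")

    def test_worker_roundtrip(self):
        q = mp.Queue()
        p = mp.Process(target=_worker, args=(q,))
        p.start()
        self.assertEqual(q.get(timeout=10), "ok")
        p.join()


if __name__ == "__main__":
    unittest.main()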
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
             return

         for model_case in ALL_OTHER_MODELS:
             # Only run a specified model
             if (
                 "ONLY_RUN" in os.environ
                 and os.environ["ONLY_RUN"] != model_case.model_path
             ):
                 continue
-            self.assert_close_logits_and_output_strs(
-                DEFAULT_PROMPTS, model_case, torch.float16
-            )
+
+            # Skip long prompts for models that do not have a long context
+            prompts = DEFAULT_PROMPTS
+            if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]:
+                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
+
+            # Assert the logits and output strs are close
+            self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)


 if __name__ == "__main__":
-    mp.set_start_method("spawn")
     unittest.main()
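The ONLY_RUN environment variable in the hunk above is an ad-hoc filter: when it is set, every ModelCase whose model_path differs is skipped, which is handy for reproducing a single model's failure locally. A standalone sketch of the same pattern, with the print standing in for the real assertion:

import os

model_paths = ["Qwen/Qwen2-1.5B", "Qwen/Qwen2.5-14B-Instruct"]

for model_path in model_paths:
    # e.g. ONLY_RUN=Qwen/Qwen2-1.5B python test_generation_models.py
    if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_path:
        continue
    print(f"would test {model_path}")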