Unverified commit aba9eae4, authored by Lianmin Zheng, committed by GitHub

Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)

parent bbd72bfc
bench_latency.py

@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


+@torch.inference_mode()
 def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch


+@torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
     model_worker_batch = batch.get_model_worker_batch()
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-@torch.inference_mode()
 def correctness_test(
     server_args,
     port_args,
@@ -287,7 +288,6 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


-@torch.inference_mode()
 def latency_test_run_once(
     run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
 ):
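The fix moves @torch.inference_mode() off the top-level correctness_test and latency_test_run_once drivers and onto the per-step extend and decode functions, so only the forward passes execute under inference mode while the drivers' setup and tensor-parallel coordination stay outside it (the tp > 1 interaction is inferred from the commit title; the diff itself only shows the decorator moves). A minimal sketch of the placement pattern, with hypothetical forward_step/run_benchmark names standing in for extend/decode and the driver:

import torch

@torch.inference_mode()
def forward_step(model, x):
    # Runs under inference mode: autograd is fully disabled and the
    # outputs are inference tensors, which is cheaper than no_grad().
    return model(x)

def run_benchmark(model, batches):
    # Intentionally undecorated: anything here that must cooperate with
    # autograd or cross-process machinery stays outside inference mode.
    return [forward_step(model, b) for b in batches]

model = torch.nn.Linear(4, 4)
outputs = run_benchmark(model, [torch.randn(2, 4) for _ in range(3)])
print(outputs[0].is_inference())  # True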
test_generation_models.py

@@ -42,13 +42,13 @@ class ModelCase:
     rouge_l_tolerance: float = 1


-# Popular models that run on CI
+# Popular models that run on the CI
 CI_MODELS = [
     ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
     ModelCase("google/gemma-2-2b"),
 ]

-# All other models
+# All other models that do not run on the CI
 ALL_OTHER_MODELS = [
     ModelCase("Qwen/Qwen2-1.5B"),
     ModelCase("Qwen/Qwen2.5-14B-Instruct"),
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]


 class TestGenerationModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn")
+
     def assert_close_logits_and_output_strs(
         self,
         prompts: List[str],
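Moving mp.set_start_method("spawn") from the __main__ guard (removed in the next hunk) into setUpClass ensures the start method is set even when a test runner imports the module and never executes the __main__ block. A self-contained sketch of the same pattern; force=True and the _child/ExampleTest names are additions for this sketch, not part of the diff:

import multiprocessing as mp
import unittest

def _child(q):
    q.put(mp.current_process().name)

class ExampleTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once before any test in the class, so "spawn" is in effect
        # regardless of how the tests are launched. set_start_method raises
        # RuntimeError if a start method was already set; force=True is used
        # here only to keep the sketch re-runnable.
        mp.set_start_method("spawn", force=True)

    def test_spawned_child(self):
        q = mp.Queue()
        p = mp.Process(target=_child, args=(q,), name="worker-0")
        p.start()
        p.join()
        self.assertEqual(q.get(), "worker-0")

if __name__ == "__main__":
    unittest.main()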
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
             return

         for model_case in ALL_OTHER_MODELS:
+            # Only run a specified model
             if (
                 "ONLY_RUN" in os.environ
                 and os.environ["ONLY_RUN"] != model_case.model_path
             ):
                 continue
-            self.assert_close_logits_and_output_strs(
-                DEFAULT_PROMPTS, model_case, torch.float16
-            )
+
+            # Skip long prompts for models that do not have a long context
+            prompts = DEFAULT_PROMPTS
+            if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]:
+                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
+
+            # Assert the logits and output strs are close
+            self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)


 if __name__ == "__main__":
-    mp.set_start_method("spawn")
     unittest.main()
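The new guard swaps DEFAULT_PROMPTS for a filtered list when a model's context window is too small for the long prompts; note it uses character length (len(p) < 1000) as a cheap proxy for token count. A standalone sketch with made-up prompts and a hypothetical prompts_for helper:

DEFAULT_PROMPTS = [
    "What is the capital of France?",
    "Summarize the following article: " + "lorem ipsum " * 500,
]

# Models assumed (for this sketch) to have a short context window.
SHORT_CONTEXT_MODELS = {"HuggingFaceTB/SmolLM-135M-Instruct"}

def prompts_for(model_path, prompts=DEFAULT_PROMPTS, max_chars=1000):
    # Character length stands in for token count, matching the diff's cutoff.
    if model_path in SHORT_CONTEXT_MODELS:
        return [p for p in prompts if len(p) < max_chars]
    return prompts

print(len(prompts_for("HuggingFaceTB/SmolLM-135M-Instruct")))  # 1
print(len(prompts_for("Qwen/Qwen2-1.5B")))  # 2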