Unverified Commit e903f695 authored by Lifu Huang, committed by GitHub

Fix potential flakiness in test_lora_qwen3 (#10250)

parent 27760fc1
@@ -24,6 +24,7 @@ from utils import (
     CI_MULTI_LORA_MODELS,
     TORCH_DTYPES,
     LoRAModelCase,
+    ensure_reproducibility,
 )

 from sglang.test.runners import HFRunner, SRTRunner
@@ -76,13 +77,6 @@ class TestLoRA(CustomTestCase):
         return batches

-    def ensure_reproducibility(self):
-        seed = 42
-        random.seed(seed)
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.use_deterministic_algorithms(True)
-
     def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
@@ -121,14 +115,14 @@ class TestLoRA(CustomTestCase):
                     f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
                 )

-                self.ensure_reproducibility()
+                ensure_reproducibility()
                 srt_outputs = srt_runner.batch_forward(
                     prompts,
                     max_new_tokens=max_new_tokens,
                     lora_paths=lora_paths,
                 )

-                self.ensure_reproducibility()
+                ensure_reproducibility()
                 hf_outputs = hf_runner.forward(
                     prompts,
                     max_new_tokens=max_new_tokens,
...
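Note on the pattern above: the helper must be invoked twice because each forward pass consumes global RNG state, so without the second call the HF pass would run under a different state than the SRT pass it is compared against. A minimal standalone sketch (not part of the diff; the helper body matches the one added to utils in the final hunk below):

import random

import torch


def ensure_reproducibility():
    # Reset every RNG the tests rely on and force deterministic kernels.
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)


# Each draw advances the global RNG state, so re-seeding right before
# each consumer is what makes both runs observe identical randomness.
ensure_reproducibility()
a = torch.randn(3)
ensure_reproducibility()
b = torch.randn(3)
assert torch.equal(a, b)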
@@ -18,7 +18,7 @@ import random
 import unittest
 from typing import List

-from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
+from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, ensure_reproducibility

 from sglang.test.runners import HFRunner, SRTRunner
 from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
@@ -59,19 +59,18 @@ TEST_MULTIPLE_BATCH_PROMPTS = [
The Transformers are large language models,
They're used to make predictions on text.
""",
-    # "AI is a field of computer science focused on",  TODO: Add it back after fixing its bug
+    "AI is a field of computer science focused on",
     "Computer science is the study of",
     "Write a short story.",
     "What are the main components of a computer?",
 ]


-class TestLoRA(CustomTestCase):
+class TestLoRAQwen3(CustomTestCase):
     def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
-                max_new_tokens = 10
+                max_new_tokens = 32
                 backend = "triton"
                 base_path = model_case.base
                 lora_adapter_paths = [a.name for a in model_case.adaptors]
@@ -133,6 +132,7 @@ class TestLoRA(CustomTestCase):
                 )

                 # Initialize runners
+                ensure_reproducibility()
                 srt_runner = SRTRunner(
                     base_path,
                     torch_dtype=torch_dtype,
@@ -140,7 +140,11 @@
                     lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                     max_loras_per_batch=len(lora_adapter_paths) + 1,
                     lora_backend=backend,
+                    sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
+                    attention_backend="torch_native",
                 )
+                ensure_reproducibility()
                 hf_runner = HFRunner(
                     base_path,
                     torch_dtype=torch_dtype,
...
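For readers unfamiliar with the two new SRTRunner flags: per the diff's own comment, sleep_on_idle=True forces all queued requests to be processed in one batch, and attention_backend="torch_native" presumably routes attention through plain PyTorch ops rather than fused kernels, which keeps the forward pass deterministic. A minimal sketch of the resulting setup, with hypothetical model and adapter names standing in for the test's parameters:

import torch

from sglang.test.runners import SRTRunner

# Hypothetical base model and adapter names, for illustration only.
srt_runner = SRTRunner(
    "Qwen/Qwen3-0.6B",
    torch_dtype=torch.float16,
    lora_paths=["lora_adapter_a", "lora_adapter_b"],
    max_loras_per_batch=3,
    lora_backend="triton",
    sleep_on_idle=True,  # process all queued requests in a single batch
    attention_backend="torch_native",  # deterministic attention path
)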
@@ -13,6 +13,7 @@
 # ==============================================================================
 import dataclasses
+import random
 from typing import List

 import torch
@@ -386,3 +387,11 @@ def run_lora_test_by_batch(
             srt_no_lora_outputs.output_strs[i].strip(" "),
             hf_no_lora_outputs.output_strs[i].strip(" "),
         )
+
+
+def ensure_reproducibility():
+    seed = 42
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.use_deterministic_algorithms(True)
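One caveat on torch.use_deterministic_algorithms(True), which the new helper enables globally: operations that lack a deterministic implementation raise a RuntimeError at call time, and on CUDA 10.2+ the deterministic cuBLAS paths additionally require a workspace environment variable to be set before the first CUDA call. A minimal guard, following PyTorch's reproducibility notes:

import os

import torch

# Must be exported before any CUDA context exists; required for
# deterministic cuBLAS per PyTorch's reproducibility documentation.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.use_deterministic_algorithms(True)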