Unverified commit 2373faa3, authored by Lifu Huang, committed by GitHub

Fix flakiness in LoRA batch test. (#7552)

parent 9efb2993
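
The change pins the randomness on both sides of the SRT-vs-HF comparison: every RNG is reseeded before each forward pass, and the server is made to process each test batch as a single scheduler batch (sleep_on_idle plus the torch_native attention backend, per the comment in the diff). As a reference only, here is a standalone sketch of the reseeding pattern this commit introduces; the diff adds the same logic as a test-class method of the same name:

    import random
    import torch

    def ensure_reproducibility(seed: int = 42) -> None:
        # Pin every RNG source so repeated generations come out identical.
        random.seed(seed)                         # Python-level sampling (prompts, adapter picks)
        torch.manual_seed(seed)                   # CPU RNG
        torch.cuda.manual_seed_all(seed)          # all visible CUDA devices
        torch.use_deterministic_algorithms(True)  # error out on nondeterministic kernels
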
@@ -503,6 +503,7 @@ class SRTRunner:
         disable_overlap_schedule: bool = False,
         disable_custom_all_reduce: bool = False,
         torchao_config: Optional[str] = None,
+        sleep_on_idle=False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -540,6 +541,7 @@ class SRTRunner:
             disable_overlap_schedule=disable_overlap_schedule,
             cuda_graph_max_bs=4,
             disable_custom_all_reduce=disable_custom_all_reduce,
+            sleep_on_idle=sleep_on_idle,
             **spec_kwargs,
         )
@@ -18,6 +18,7 @@ import random
 import unittest
 from typing import List
+import torch
 from utils import (
     ALL_OTHER_MULTI_LORA_MODELS,
     CI_MULTI_LORA_MODELS,
@@ -46,7 +47,7 @@ TEST_MULTIPLE_BATCH_PROMPTS = [
     The Transformers are large language models,
     They're used to make predictions on text.
     """,
-    # "AI is a field of computer science focused on", TODO: Add it back after fixing its bug
+    "AI is a field of computer science focused on",
     "Computer science is the study of",
     "Write a short story.",
     "What are the main components of a computer?",
@@ -54,8 +55,36 @@ class TestLoRA(CustomTestCase):
 class TestLoRA(CustomTestCase):
+    def _create_test_samples(
+        self, lora_adapter_paths: List[str], repeated_trials: int = 3
+    ):
+        random.seed(42)  # Ensure reproducibility
+        patterns = [
+            [None, lora_adapter_paths[0], lora_adapter_paths[1]],
+            [lora_adapter_paths[0], None, lora_adapter_paths[1]],
+            [lora_adapter_paths[0], lora_adapter_paths[1], None],
+            [None, lora_adapter_paths[1], None],
+            [None, None, None],
+        ]
+        batches = [
+            [random.choice(pattern) for _ in range(3)]
+            for pattern in patterns
+            for _ in range(repeated_trials)
+        ]
+        return batches
+
+    def ensure_reproducibility(self):
+        seed = 42
+        random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.use_deterministic_algorithms(True)
+
     def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 32
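
Note that the new _create_test_samples helper only fixes the per-request adapter assignments; prompts are sampled separately inside the comparison loop further down. A minimal standalone sketch of the same sampling logic, using hypothetical adapter names and only two of the five patterns, just to show the shape of the result:

    import random

    random.seed(42)
    adapters = ["adapter_a", "adapter_b"]  # hypothetical names, for illustration only
    patterns = [
        [None, adapters[0], adapters[1]],
        [None, None, None],
    ]
    batches = [
        [random.choice(pattern) for _ in range(3)]  # 3 requests per batch
        for pattern in patterns
        for _ in range(3)  # repeated_trials
    ]
    # -> 2 patterns x 3 trials = 6 batches, e.g. ["adapter_a", None, "adapter_a"]
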
@@ -64,57 +93,6 @@ class TestLoRA(CustomTestCase):
                 lora_adapter_paths = [a.name for a in model_case.adaptors]
                 assert len(lora_adapter_paths) >= 2
-                batches = [
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [
-                            None,
-                            lora_adapter_paths[0],
-                            lora_adapter_paths[1],
-                        ],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [
-                            lora_adapter_paths[0],
-                            None,
-                            lora_adapter_paths[1],
-                        ],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [lora_adapter_paths[0], lora_adapter_paths[1], None],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [None, lora_adapter_paths[1], None],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [None, None, None],
-                    ),
-                ]
                 print(
                     f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
                 )
@@ -128,23 +106,31 @@ class TestLoRA(CustomTestCase):
                     max_loras_per_batch=len(lora_adapter_paths) + 1,
                     lora_backend=backend,
                     disable_radix_cache=True,
+                    sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
+                    attention_backend="torch_native",
                 )
                 hf_runner = HFRunner(
                     base_path, torch_dtype=torch_dtype, model_type="generation"
                 )
+                batches = self._create_test_samples(lora_adapter_paths)
                 with srt_runner, hf_runner:
-                    for i, (prompts, lora_paths) in enumerate(batches):
+                    for i, lora_paths in enumerate(batches, start=1):
+                        prompts = [
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS) for _ in range(3)
+                        ]
                         print(
-                            f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
+                            f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
                         )
+                        self.ensure_reproducibility()
                         srt_outputs = srt_runner.batch_forward(
                             prompts,
                             max_new_tokens=max_new_tokens,
                             lora_paths=lora_paths,
                         )
+                        self.ensure_reproducibility()
                         hf_outputs = hf_runner.forward(
                             prompts,
                             max_new_tokens=max_new_tokens,
@@ -167,7 +153,7 @@ class TestLoRA(CustomTestCase):
                             f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
                         )
-                        print(f"--- Batch {i+1} Comparison Passed --- ")
+                        print(f"--- Batch {i} Comparison Passed --- ")

     def test_ci_lora_models(self):
         self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS)
@@ -13,7 +13,7 @@ class TestFile:
 suites = {
     "per-commit": [
-        TestFile("models/lora/test_lora.py", 76),
+        TestFile("models/lora/test_lora.py", 200),
         TestFile("models/lora/test_lora_backend.py", 99),
         TestFile("models/lora/test_multi_lora_backend.py", 60),
         TestFile("models/lora/test_lora_cuda_graph.py", 250),