Unverified Commit b0d20cde authored by Lifu Huang, committed by GitHub

Set csgmv as default lora backend. (#11488)

parent cbac4997
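
This commit makes csgmv the default LoRA backend in the benchmark script, in ServerArgs, and in the test harness, and drops the explicit backend="triton" overrides the LoRA tests no longer need. The Triton kernels stay selectable by passing the backend explicitly; the following minimal sketch is not part of the commit, uses the SRTRunner harness whose lora_backend parameter is updated in the diff below, and relies on placeholder model and adapter paths plus a forward() call whose import path and signature are assumed from the surrounding test suite.

import torch

# Import path assumed (sglang's test runners module); it is not shown in this diff.
from sglang.test.runners import SRTRunner

# Sketch: pin the previous Triton LoRA kernels now that the default is "csgmv".
# "<base-model>" and "<adapter>" are placeholders, not real paths.
with SRTRunner(
    "<base-model>",
    torch_dtype=torch.float16,
    model_type="generation",
    lora_paths=["<adapter>"],
    max_loras_per_batch=1,
    lora_backend="triton",  # explicit override of the new "csgmv" default
) as srt_runner:
    # forward() usage assumed from the LoRA tests; returns the generated outputs.
    outputs = srt_runner.forward(
        ["AI is a field of computer science focused on"],
        max_new_tokens=32,
        lora_paths=["<adapter>"],
    )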
@@ -53,7 +53,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lora-backend",
         type=str,
-        default="triton",
+        default="csgmv",
     )
     parser.add_argument(
         "--tp-size",
...
@@ -309,8 +309,8 @@ class ServerArgs:
     ] = None
     max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
+    lora_backend: str = "csgmv"
     lora_eviction_policy: str = DEFAULT_LORA_EVICTION_POLICY
-    lora_backend: str = "triton"
     max_lora_chunk_size: Optional[int] = 16
     # Kernel backend
...
@@ -496,7 +496,7 @@ class SRTRunner:
         attention_backend: Optional[str] = None,
         prefill_attention_backend: Optional[str] = None,
         decode_attention_backend: Optional[str] = None,
-        lora_backend: str = "triton",
+        lora_backend: str = "csgmv",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
         chunked_prefill_size: Optional[int] = None,
...
@@ -81,13 +81,12 @@ class TestLoRA(CustomTestCase):
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 32
-                backend = "triton"
                 base_path = model_case.base
                 lora_adapter_paths = [a.name for a in model_case.adaptors]
                 assert len(lora_adapter_paths) >= 2
                 print(
-                    f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
+                    f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---"
                 )
                 # Initialize runners
@@ -97,7 +96,6 @@ class TestLoRA(CustomTestCase):
                     model_type="generation",
                     lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                     max_loras_per_batch=len(lora_adapter_paths) + 1,
-                    lora_backend=backend,
                     sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
                     attention_backend="torch_native",
                 )
@@ -142,7 +140,7 @@ class TestLoRA(CustomTestCase):
                 if rouge_score < rouge_tol:
                     raise AssertionError(
                         f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
-                        f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
+                        f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'"
                     )
                 print(f"--- Batch {i} Comparison Passed --- ")
...
@@ -62,7 +62,6 @@ class TestLoRACudaGraph(CustomTestCase):
                     model_case,
                     torch_dtype,
                     max_new_tokens=32,
-                    backend="triton",
                     disable_cuda_graph=True,
                     test_tag="without_cuda_graph",
                 )
@@ -77,7 +76,6 @@ class TestLoRACudaGraph(CustomTestCase):
                     model_case,
                     torch_dtype,
                     max_new_tokens=32,
-                    backend="triton",
                     disable_cuda_graph=False,
                     test_tag="cuda_graph_padding",
                 )
...
@@ -83,7 +83,6 @@ class TestLoRAEviction(CustomTestCase):
     ):
         REUSED_LORA_NAME = "lora"
         max_new_tokens = 256
-        backend = "triton"
         torch_dtype = torch.float16
         base_path = BASE_MODEL
         assert len(lora_paths) >= 2
@@ -96,7 +95,6 @@ class TestLoRAEviction(CustomTestCase):
             model_type="generation",
             lora_paths=initial_lora_paths,
             max_loras_per_batch=1,
-            lora_backend=backend,
             enable_lora=True,
             max_lora_rank=256,
             lora_target_modules=["all"],
...
@@ -71,7 +71,6 @@ class TestLoRAQwen3(CustomTestCase):
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 32
-                backend = "triton"
                 base_path = model_case.base
                 lora_adapter_paths = [a.name for a in model_case.adaptors]
                 assert len(lora_adapter_paths) >= 2
@@ -128,7 +127,7 @@ class TestLoRAQwen3(CustomTestCase):
                 ]
                 print(
-                    f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
+                    f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---"
                 )
                 # Initialize runners
@@ -139,7 +138,6 @@ class TestLoRAQwen3(CustomTestCase):
                     model_type="generation",
                     lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                     max_loras_per_batch=len(lora_adapter_paths) + 1,
-                    lora_backend=backend,
                     sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
                     attention_backend="torch_native",
                 )
@@ -183,7 +181,7 @@ class TestLoRAQwen3(CustomTestCase):
                 if rouge_score < rouge_tol:
                     raise AssertionError(
                         f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
-                        f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
+                        f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'"
                     )
                 print(f"--- Batch {i+1} Comparison Passed --- ")
...
@@ -44,7 +44,6 @@ class TestLoRARadixCache(CustomTestCase):
         torch_dtype = torch.float16
         max_new_tokens = 32
-        backend = "triton"
         batch_prompts = (
             PROMPTS
             if not model_case.skip_long_prompt
@@ -57,7 +56,6 @@ class TestLoRARadixCache(CustomTestCase):
             model_case,
             torch_dtype,
             max_new_tokens=max_new_tokens,
-            backend=backend,
             disable_radix_cache=False,
             test_tag="lora-with-radix-cache",
         )
@@ -68,7 +66,6 @@ class TestLoRARadixCache(CustomTestCase):
             model_case,
             torch_dtype,
             max_new_tokens=max_new_tokens,
-            backend=backend,
             disable_radix_cache=True,
             test_tag="lora-without-radix-cache",
         )
...
@@ -48,7 +48,6 @@ class TestLoRATP(CustomTestCase):
                     model_case,
                     torch_dtype,
                     max_new_tokens=32,
-                    backend="triton",
                     test_tag=f"tp={tp_size}",
                 )
...
@@ -763,7 +763,7 @@ class LoRAUpdateTestSessionBase:
         max_lora_rank: Optional[int],
         enable_lora: Optional[bool] = None,
         lora_target_modules: Optional[List[str]] = None,
-        lora_backend: str = "triton",
+        lora_backend: str = "csgmv",
         disable_cuda_graph: bool = False,
         cuda_graph_max_bs: int = 4,
     ):
...
@@ -14,7 +14,7 @@
 import dataclasses
 import random
-from typing import List
+from typing import List, Optional
 import torch
@@ -50,7 +50,7 @@ class LoRAModelCase:
 TORCH_DTYPES = [torch.float16]
-BACKENDS = ["triton"]
+BACKENDS = ["triton", "csgmv"]
 DEFAULT_PROMPTS = [
     "AI is a field of computer science focused on",
     """
@@ -135,7 +135,7 @@ def run_lora_test_one_by_one(
     model_case: LoRAModelCase,
     torch_dtype: torch.dtype,
     max_new_tokens: int,
-    backend: str,
+    backend: str = "csgmv",
     disable_cuda_graph: bool = False,
     disable_radix_cache: bool = False,
     mem_fraction_static: float = 0.88,
...
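
In the utilities hunk above, BACKENDS now lists both kernels and run_lora_test_one_by_one gains a "csgmv" default for its backend parameter. Below is a hedged sketch of how a suite could sweep both entries; the loop itself is not part of this commit, the helper and constants are assumed to be in scope from the utils module shown in the last hunk (its import path is not given in the diff), and the leading prompt argument is inferred from the radix-cache call site above.

# Hypothetical sweep over the widened BACKENDS list. BACKENDS, TORCH_DTYPES,
# run_lora_test_one_by_one, and max_new_tokens=32 come from the hunks above;
# prompts and model_cases are assumed to be defined by the calling test.
for backend in BACKENDS:  # ["triton", "csgmv"] after this change
    for model_case in model_cases:
        for torch_dtype in TORCH_DTYPES:
            run_lora_test_one_by_one(
                prompts,            # assumed first positional argument (prompt list)
                model_case,
                torch_dtype,
                max_new_tokens=32,
                backend=backend,    # overrides the new "csgmv" default per iteration
            )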