"test/vscode:/vscode.git/clone" did not exist on "aeaccf69e08376c88ebe684eaad8181cea732494"
Unverified commit ce832d70, authored by Lifu Huang, committed by GitHub

Add env var to control custom Triton kernel cache and set CSGMV as default backend. (#12176)

parent 88596739
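In short, this commit makes two changes: the default LoRA backend moves from "triton" to "csgmv", and the custom Triton kernel cache becomes an opt-in controlled by the new SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE environment variable (off by default). A minimal sketch of setting both knobs from a launcher script; the entry point and model path below are assumptions for illustration, not part of this diff:

import os
import subprocess

# Opt in to the new custom Triton kernel cache (it defaults to off).
os.environ["SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE"] = "1"

# Pin the LoRA backend explicitly; without the flag it now defaults to "csgmv".
subprocess.run(
    [
        "python", "-m", "sglang.launch_server",  # assumed entry point
        "--model-path", "YOUR_BASE_MODEL",       # placeholder model
        "--lora-backend", "triton",              # or "csgmv" (the new default)
    ],
    check=True,
)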
......@@ -53,7 +53,7 @@ if __name__ == "__main__":
parser.add_argument(
"--lora-backend",
type=str,
default="triton",
default="csgmv",
)
parser.add_argument(
"--tp-size",
......
......@@ -180,6 +180,7 @@ class Envs:
# Triton
SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)
+SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE = EnvBool(False)
# Torch Compile
SGLANG_ENABLE_TORCH_COMPILE = EnvBool(False)
......
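The new flag follows the same EnvBool pattern as its neighbors: it is named after the attribute, parses the environment variable into a boolean, and exposes .get() and .name (both are used in the decorator change further down). A simplified stand-in that assumes this behavior; it is not the actual sglang Envs implementation:

import os

class EnvBool:
    # Simplified env-backed boolean flag with .get() and .name.
    def __init__(self, default: bool):
        self.default = default
        self.name = None

    def __set_name__(self, owner, attr_name):
        # Assumption: the attribute name on Envs doubles as the environment variable name.
        self.name = attr_name

    def get(self) -> bool:
        raw = os.environ.get(self.name)
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes", "on")

class Envs:
    SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE = EnvBool(False)

envs = Envs()
print(envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.get())  # False unless the variable is set to a truthy value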
......@@ -331,7 +331,7 @@ class ServerArgs:
max_loaded_loras: Optional[int] = None
max_loras_per_batch: int = 8
lora_eviction_policy: str = "lru"
-lora_backend: str = "triton"
+lora_backend: str = "csgmv"
max_lora_chunk_size: Optional[int] = 16
# Kernel backend
......
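ServerArgs picks up the same default, so any deployment that does not pass --lora-backend now runs the csgmv backend. A quick way to confirm the dataclass default without constructing a full ServerArgs; the import path is an assumption:

import dataclasses

from sglang.srt.server_args import ServerArgs  # assumed import path

defaults = {f.name: f.default for f in dataclasses.fields(ServerArgs)}
assert defaults["lora_backend"] == "csgmv"  # new default introduced by this commit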
......@@ -3571,7 +3571,17 @@ def cached_triton_kernel(key_fn=None):
"""
def decorator(fn):
-return CachedKernel(fn, key_fn)
+if envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.get():
+    logger.debug(
+        f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = True. Using custom triton kernel cache."
+    )
+    return CachedKernel(fn, key_fn)
+else:
+    # Fallback to the native triton cache.
+    logger.debug(
+        f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = False. Using native triton kernel cache."
+    )
+    return fn
return decorator
......
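The decorator change is a pure toggle: with the flag on, the kernel function is wrapped in CachedKernel(fn, key_fn); with it off, the function is returned untouched and Triton's native cache applies. The sketch below mirrors only that control flow using a generic memoizing wrapper; it is not the real CachedKernel, which caches Triton kernel specializations rather than call results:

import functools
import os

def toy_cached_kernel(key_fn=None):
    # Toy analogue of cached_triton_kernel: the env flag decides whether to wrap at all.
    def decorator(fn):
        flag = os.environ.get("SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE", "false").lower()
        if flag in ("1", "true", "yes", "on"):
            cache = {}

            @functools.wraps(fn)
            def wrapped(*args, **kwargs):
                # Illustrative key; the real key_fn derives a specialization key for the kernel.
                key = key_fn(kwargs) if key_fn else (args, tuple(sorted(kwargs.items())))
                if key not in cache:
                    cache[key] = fn(*args, **kwargs)
                return cache[key]

            return wrapped
        # Flag off: hand the function back unchanged, matching the native-cache branch above.
        return fn

    return decorator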
......@@ -511,7 +511,7 @@ class SRTRunner:
attention_backend: Optional[str] = None,
prefill_attention_backend: Optional[str] = None,
decode_attention_backend: Optional[str] = None,
-lora_backend: str = "triton",
+lora_backend: str = "csgmv",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
chunked_prefill_size: Optional[int] = None,
......
......@@ -81,13 +81,12 @@ class TestLoRA(CustomTestCase):
for model_case in model_cases:
for torch_dtype in TORCH_DTYPES:
max_new_tokens = 32
backend = "triton"
base_path = model_case.base
lora_adapter_paths = [a.name for a in model_case.adaptors]
assert len(lora_adapter_paths) >= 2
print(
f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---"
)
# Initialize runners
......@@ -97,7 +96,6 @@ class TestLoRA(CustomTestCase):
model_type="generation",
lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
max_loras_per_batch=len(lora_adapter_paths) + 1,
-lora_backend=backend,
sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch.
attention_backend="torch_native",
)
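With lora_backend=backend removed, runners like this one inherit the new "csgmv" default instead of being pinned to "triton". Roughly, the resulting call shape; positional arguments are abridged and any name not shown in the diff is an assumption:

with SRTRunner(
    base_path,  # base model from the test case
    torch_dtype=torch_dtype,
    model_type="generation",
    lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
    max_loras_per_batch=len(lora_adapter_paths) + 1,
    # lora_backend is no longer passed, so the runner's "csgmv" default applies
    sleep_on_idle=True,
    attention_backend="torch_native",
) as srt_runner:
    ...  # run batches and compare against the reference, as in the rest of the test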
......@@ -142,7 +140,7 @@ class TestLoRA(CustomTestCase):
if rouge_score < rouge_tol:
raise AssertionError(
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'"
)
print(f"--- Batch {i} Comparison Passed --- ")
......
......@@ -62,7 +62,6 @@ class TestLoRACudaGraph(CustomTestCase):
model_case,
torch_dtype,
max_new_tokens=32,
backend="triton",
disable_cuda_graph=True,
test_tag="without_cuda_graph",
)
......@@ -77,7 +76,6 @@ class TestLoRACudaGraph(CustomTestCase):
model_case,
torch_dtype,
max_new_tokens=32,
backend="triton",
disable_cuda_graph=False,
test_tag="cuda_graph_padding",
)
......
......@@ -83,7 +83,6 @@ class TestLoRAEviction(CustomTestCase):
):
REUSED_LORA_NAME = "lora"
max_new_tokens = 256
backend = "triton"
torch_dtype = torch.float16
base_path = BASE_MODEL
assert len(lora_paths) >= 2
......@@ -96,7 +95,6 @@ class TestLoRAEviction(CustomTestCase):
model_type="generation",
lora_paths=initial_lora_paths,
max_loras_per_batch=1,
-lora_backend=backend,
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
......
......@@ -71,7 +71,6 @@ class TestLoRAQwen3(CustomTestCase):
for model_case in model_cases:
for torch_dtype in TORCH_DTYPES:
max_new_tokens = 32
backend = "triton"
base_path = model_case.base
lora_adapter_paths = [a.name for a in model_case.adaptors]
assert len(lora_adapter_paths) >= 2
......@@ -128,7 +127,7 @@ class TestLoRAQwen3(CustomTestCase):
]
print(
f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---"
)
# Initialize runners
......@@ -139,7 +138,6 @@ class TestLoRAQwen3(CustomTestCase):
model_type="generation",
lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
max_loras_per_batch=len(lora_adapter_paths) + 1,
-lora_backend=backend,
sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch.
attention_backend="torch_native",
)
......@@ -183,7 +181,7 @@ class TestLoRAQwen3(CustomTestCase):
if rouge_score < rouge_tol:
raise AssertionError(
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'"
)
print(f"--- Batch {i+1} Comparison Passed --- ")
......
......@@ -44,7 +44,6 @@ class TestLoRARadixCache(CustomTestCase):
torch_dtype = torch.float16
max_new_tokens = 32
backend = "triton"
batch_prompts = (
PROMPTS
if not model_case.skip_long_prompt
......@@ -57,7 +56,6 @@ class TestLoRARadixCache(CustomTestCase):
model_case,
torch_dtype,
max_new_tokens=max_new_tokens,
-backend=backend,
disable_radix_cache=False,
test_tag="lora-with-radix-cache",
)
......@@ -68,7 +66,6 @@ class TestLoRARadixCache(CustomTestCase):
model_case,
torch_dtype,
max_new_tokens=max_new_tokens,
-backend=backend,
disable_radix_cache=True,
test_tag="lora-without-radix-cache",
)
......
......@@ -48,7 +48,6 @@ class TestLoRATP(CustomTestCase):
model_case,
torch_dtype,
max_new_tokens=32,
backend="triton",
test_tag=f"tp={tp_size}",
)
......
......@@ -764,7 +764,7 @@ class LoRAUpdateTestSessionBase:
max_lora_rank: Optional[int],
enable_lora: Optional[bool] = None,
lora_target_modules: Optional[List[str]] = None,
-lora_backend: str = "triton",
+lora_backend: str = "csgmv",
disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4,
):
......
......@@ -14,7 +14,7 @@
import dataclasses
import random
-from typing import List
+from typing import List, Optional
import torch
......@@ -50,7 +50,7 @@ class LoRAModelCase:
TORCH_DTYPES = [torch.float16]
BACKENDS = ["triton"]
BACKENDS = ["triton", "csgmv"]
DEFAULT_PROMPTS = [
"AI is a field of computer science focused on",
"""
......@@ -135,7 +135,7 @@ def run_lora_test_one_by_one(
model_case: LoRAModelCase,
torch_dtype: torch.dtype,
max_new_tokens: int,
-backend: str,
+backend: str = "csgmv",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
mem_fraction_static: float = 0.88,
......@@ -283,7 +283,7 @@ def run_lora_test_by_batch(
model_case: LoRAModelCase,
torch_dtype: torch.dtype,
max_new_tokens: int,
-backend: str,
+backend: str = "csgmv",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
mem_fraction_static: float = 0.88,
......
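Since the shared helpers now default to backend="csgmv" and BACKENDS lists both implementations, a test that wants to exercise both code paths can still iterate explicitly. A sketch; the prompt list and model case come from the surrounding test utilities, and the positional argument order is assumed:

for backend in BACKENDS:              # ["triton", "csgmv"]
    for torch_dtype in TORCH_DTYPES:  # [torch.float16]
        run_lora_test_one_by_one(
            DEFAULT_PROMPTS,
            model_case,               # a LoRAModelCase defined by the test module
            torch_dtype,
            max_new_tokens=32,
            backend=backend,          # explicit override of the "csgmv" default
        )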