"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "4c87e777d855636b9eda7ec87bcbbf12b62caed3"
Unverified commit ce832d70, authored by Lifu Huang, committed by GitHub

Add an env var to control the custom Triton kernel cache and set CSGMV as the default LoRA backend. (#12176)

parent 88596739
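
Usage note (not part of the diff): after this change the custom Triton kernel cache is opt-in via the new environment variable, and `csgmv` is the default LoRA backend. The sketch below shows how a launch script might restore the previous behavior; the launcher module name and the model path are assumptions, while `--lora-backend` matches the flag in the diff below.

```python
# Illustrative sketch only, not part of this commit. The launcher module name
# and the model path are assumptions; "--lora-backend" matches the flag shown
# in the diff below.
import os
import subprocess

# New env var introduced here; it defaults to False, so SGLang falls back to
# Triton's native kernel cache unless it is set before the server starts.
env = dict(os.environ, SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE="1")

# lora_backend now defaults to "csgmv"; pass "triton" explicitly to keep the
# previous backend.
subprocess.run(
    [
        "python", "-m", "sglang.launch_server",  # launcher module assumed
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--lora-backend", "triton",
    ],
    env=env,
    check=True,
)
```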
@@ -53,7 +53,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lora-backend",
         type=str,
-        default="triton",
+        default="csgmv",
     )
     parser.add_argument(
         "--tp-size",
...
@@ -180,6 +180,7 @@ class Envs:
     # Triton
     SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)
+    SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE = EnvBool(False)

     # Torch Compile
     SGLANG_ENABLE_TORCH_COMPILE = EnvBool(False)
...
@@ -331,7 +331,7 @@ class ServerArgs:
     max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
     lora_eviction_policy: str = "lru"
-    lora_backend: str = "triton"
+    lora_backend: str = "csgmv"
     max_lora_chunk_size: Optional[int] = 16

     # Kernel backend
...
@@ -3571,7 +3571,17 @@ def cached_triton_kernel(key_fn=None):
     """

     def decorator(fn):
-        return CachedKernel(fn, key_fn)
+        if envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.get():
+            logger.debug(
+                f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = True. Using custom triton kernel cache."
+            )
+            return CachedKernel(fn, key_fn)
+        else:
+            # Fallback to the native triton cache.
+            logger.debug(
+                f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = False. Using native triton kernel cache."
+            )
+            return fn

     return decorator
...
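
For context on the gating above: the decorator either wraps the kernel in `CachedKernel` or returns it untouched, so the default path preserves Triton's native caching. The block below is a self-contained sketch of that pattern; `CachedKernel` here is a stand-in stub, not SGLang's implementation, and the env-var parsing is simplified compared to `EnvBool`.

```python
# Self-contained sketch of the env-gated decorator pattern shown in the diff.
# CachedKernel is a stand-in stub and the env parsing is simplified; this is
# not SGLang's implementation.
import os


class CachedKernel:
    """Stand-in for a wrapper that would cache kernel launch metadata."""

    def __init__(self, fn, key_fn=None):
        self.fn = fn
        self.key_fn = key_fn
        self.cache = {}  # a real wrapper would memoize per-key launch state here

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


def cached_triton_kernel(key_fn=None):
    def decorator(fn):
        # Opt-in wrapping: only use the custom cache when the env var is set;
        # otherwise return fn unchanged so the native cache applies.
        enabled = os.environ.get("SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE", "false").lower()
        if enabled in ("1", "true"):
            return CachedKernel(fn, key_fn)
        return fn

    return decorator
```

Because `fn` is returned unchanged in the default case, existing kernels behave exactly as before this commit unless the variable is explicitly enabled.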
@@ -511,7 +511,7 @@ class SRTRunner:
        attention_backend: Optional[str] = None,
        prefill_attention_backend: Optional[str] = None,
        decode_attention_backend: Optional[str] = None,
-       lora_backend: str = "triton",
+       lora_backend: str = "csgmv",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
        chunked_prefill_size: Optional[int] = None,
...
@@ -81,13 +81,12 @@ class TestLoRA(CustomTestCase):
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 32
-                backend = "triton"
                 base_path = model_case.base
                 lora_adapter_paths = [a.name for a in model_case.adaptors]
                 assert len(lora_adapter_paths) >= 2

                 print(
-                    f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
+                    f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---"
                 )

                 # Initialize runners
@@ -97,7 +96,6 @@ class TestLoRA(CustomTestCase):
                     model_type="generation",
                     lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                     max_loras_per_batch=len(lora_adapter_paths) + 1,
-                    lora_backend=backend,
                     sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
                     attention_backend="torch_native",
                 )
@@ -142,7 +140,7 @@ class TestLoRA(CustomTestCase):
                 if rouge_score < rouge_tol:
                     raise AssertionError(
                         f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
-                        f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
+                        f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'"
                     )
                 print(f"--- Batch {i} Comparison Passed --- ")
...
@@ -62,7 +62,6 @@ class TestLoRACudaGraph(CustomTestCase):
                     model_case,
                     torch_dtype,
                     max_new_tokens=32,
-                    backend="triton",
                     disable_cuda_graph=True,
                     test_tag="without_cuda_graph",
                 )
@@ -77,7 +76,6 @@ class TestLoRACudaGraph(CustomTestCase):
                     model_case,
                     torch_dtype,
                     max_new_tokens=32,
-                    backend="triton",
                     disable_cuda_graph=False,
                     test_tag="cuda_graph_padding",
                 )
...
@@ -83,7 +83,6 @@ class TestLoRAEviction(CustomTestCase):
    ):
        REUSED_LORA_NAME = "lora"
        max_new_tokens = 256
-       backend = "triton"
        torch_dtype = torch.float16
        base_path = BASE_MODEL
        assert len(lora_paths) >= 2
@@ -96,7 +95,6 @@ class TestLoRAEviction(CustomTestCase):
            model_type="generation",
            lora_paths=initial_lora_paths,
            max_loras_per_batch=1,
-           lora_backend=backend,
            enable_lora=True,
            max_lora_rank=256,
            lora_target_modules=["all"],
...
@@ -71,7 +71,6 @@ class TestLoRAQwen3(CustomTestCase):
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 32
-                backend = "triton"
                 base_path = model_case.base
                 lora_adapter_paths = [a.name for a in model_case.adaptors]
                 assert len(lora_adapter_paths) >= 2
@@ -128,7 +127,7 @@ class TestLoRAQwen3(CustomTestCase):
                 ]
                 print(
-                    f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
+                    f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---"
                 )

                 # Initialize runners
@@ -139,7 +138,6 @@ class TestLoRAQwen3(CustomTestCase):
                     model_type="generation",
                     lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                     max_loras_per_batch=len(lora_adapter_paths) + 1,
-                    lora_backend=backend,
                     sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
                     attention_backend="torch_native",
                 )
@@ -183,7 +181,7 @@ class TestLoRAQwen3(CustomTestCase):
                 if rouge_score < rouge_tol:
                     raise AssertionError(
                         f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
-                        f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
+                        f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'"
                     )
                 print(f"--- Batch {i+1} Comparison Passed --- ")
...
@@ -44,7 +44,6 @@ class TestLoRARadixCache(CustomTestCase):
                torch_dtype = torch.float16
                max_new_tokens = 32
-               backend = "triton"
                batch_prompts = (
                    PROMPTS
                    if not model_case.skip_long_prompt
@@ -57,7 +56,6 @@ class TestLoRARadixCache(CustomTestCase):
                    model_case,
                    torch_dtype,
                    max_new_tokens=max_new_tokens,
-                   backend=backend,
                    disable_radix_cache=False,
                    test_tag="lora-with-radix-cache",
                )
@@ -68,7 +66,6 @@ class TestLoRARadixCache(CustomTestCase):
                    model_case,
                    torch_dtype,
                    max_new_tokens=max_new_tokens,
-                   backend=backend,
                    disable_radix_cache=True,
                    test_tag="lora-without-radix-cache",
                )
...
@@ -48,7 +48,6 @@ class TestLoRATP(CustomTestCase):
                    model_case,
                    torch_dtype,
                    max_new_tokens=32,
-                   backend="triton",
                    test_tag=f"tp={tp_size}",
                )
...
@@ -764,7 +764,7 @@ class LoRAUpdateTestSessionBase:
        max_lora_rank: Optional[int],
        enable_lora: Optional[bool] = None,
        lora_target_modules: Optional[List[str]] = None,
-       lora_backend: str = "triton",
+       lora_backend: str = "csgmv",
        disable_cuda_graph: bool = False,
        cuda_graph_max_bs: int = 4,
    ):
...
@@ -14,7 +14,7 @@
 import dataclasses
 import random
-from typing import List
+from typing import List, Optional

 import torch
@@ -50,7 +50,7 @@ class LoRAModelCase:
 TORCH_DTYPES = [torch.float16]
-BACKENDS = ["triton"]
+BACKENDS = ["triton", "csgmv"]

 DEFAULT_PROMPTS = [
     "AI is a field of computer science focused on",
     """
@@ -135,7 +135,7 @@ def run_lora_test_one_by_one(
     model_case: LoRAModelCase,
     torch_dtype: torch.dtype,
     max_new_tokens: int,
-    backend: str,
+    backend: str = "csgmv",
     disable_cuda_graph: bool = False,
     disable_radix_cache: bool = False,
     mem_fraction_static: float = 0.88,
@@ -283,7 +283,7 @@ def run_lora_test_by_batch(
     model_case: LoRAModelCase,
     torch_dtype: torch.dtype,
     max_new_tokens: int,
-    backend: str,
+    backend: str = "csgmv",
     disable_cuda_graph: bool = False,
     disable_radix_cache: bool = False,
     mem_fraction_static: float = 0.88,
...
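
Since `BACKENDS` now lists both backends while `backend` defaults to `"csgmv"`, a test that wants to exercise both explicitly could loop over the list. The following is a hypothetical sketch built from the helper signatures above; `model_cases` and the enclosing test harness are assumed.

```python
# Hypothetical sketch: exercising both LoRA backends via the helper above.
# BACKENDS, TORCH_DTYPES, and run_lora_test_one_by_one come from the test
# utilities in this diff; `model_cases` is assumed to be provided by the test.
for backend in BACKENDS:  # ["triton", "csgmv"]
    for torch_dtype in TORCH_DTYPES:
        for model_case in model_cases:
            run_lora_test_one_by_one(
                model_case,
                torch_dtype,
                max_new_tokens=32,
                backend=backend,
            )
```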