"src/array/vscode:/vscode.git/clone" did not exist on "56962858ccb8a965e827da602ae21183be19edab"
Unverified Commit f4488e9d authored by Minglei Zhu's avatar Minglei Zhu Committed by GitHub
Browse files

set default attention backend for deterministic inference (#11801)

parent e68a2b5b
...@@ -44,6 +44,7 @@ from sglang.srt.utils import ( ...@@ -44,6 +44,7 @@ from sglang.srt.utils import (
is_remote_url, is_remote_url,
is_sm90_supported, is_sm90_supported,
is_sm100_supported, is_sm100_supported,
is_sm120_supported,
is_triton_kernels_available, is_triton_kernels_available,
is_valid_ipv6_address, is_valid_ipv6_address,
json_list_type, json_list_type,
...@@ -1411,9 +1412,23 @@ class ServerArgs: ...@@ -1411,9 +1412,23 @@ class ServerArgs:
) )
# Check attention backend # Check attention backend
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES: if self.attention_backend is None:
# User didn't specify attention backend, fallback based on GPU architecture
if is_sm100_supported() or is_sm120_supported():
# Blackwell and newer architectures
self.attention_backend = "flashinfer"
else:
# Hopper (SM90) and older architectures
self.attention_backend = "fa3"
logger.warning(
f"Attention backend not specified. Falling back to '{self.attention_backend}' for deterministic inference. "
f"You can explicitly set --attention-backend to one of {DETERMINISTIC_ATTENTION_BACKEND_CHOICES}."
)
elif self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
# User explicitly specified an incompatible attention backend
raise ValueError( raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference." f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference, "
f"but you explicitly specified '{self.attention_backend}'."
) )
# Currently, only FA3 supports radix cache. Support for other backends is in progress # Currently, only FA3 supports radix cache. Support for other backends is in progress
......
...@@ -174,6 +174,15 @@ def is_blackwell(): ...@@ -174,6 +174,15 @@ def is_blackwell():
return torch.cuda.get_device_capability()[0] == 10 return torch.cuda.get_device_capability()[0] == 10
@lru_cache(maxsize=1)
def is_sm120_supported(device=None) -> bool:
    """Return True if the current GPU is SM 12.x (compute capability major == 12)
    and the CUDA toolkit is at least 12.8.

    Args:
        device: Optional CUDA device index/handle passed through to
            ``torch.cuda.get_device_capability``; ``None`` means the current device.

    Returns:
        bool: True only on a CUDA-alike platform with an SM120 GPU and CUDA >= 12.8.
    """
    if not is_cuda_alike():
        return False
    cuda_version = torch.version.cuda
    # On ROCm builds is_cuda_alike() can be True while torch.version.cuda is None;
    # guard before parsing to avoid a TypeError.
    if cuda_version is None:
        return False
    # Compare numerically: a lexicographic string compare would rank "12.10" < "12.8".
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    return torch.cuda.get_device_capability(device)[0] == 12 and (major, minor) >= (12, 8)
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def is_sm100_supported(device=None) -> bool: def is_sm100_supported(device=None) -> bool:
if not is_cuda_alike(): if not is_cuda_alike():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment