Unverified Commit 118b6af3 authored by Yineng Zhang, committed by GitHub

feat: add should_use_tensor_core (#2179)

parent 9449a954
@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
-        # Parse constants
-        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = get_bool_env_var(
-                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-            )
-        else:
-            if not _grouped_size_compiled_for_decode_kernels(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-            ):
-                self.decode_use_tensor_cores = True
-            else:
-                self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
 
         self.max_context_len = model_runner.model_config.context_len
...
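Note: the new helper keeps the SGLANG_FLASHINFER_USE_TENSOR_CORE override that the removed branch handled with get_bool_env_var. A minimal sketch of how that override behaves, with made-up head counts that are not taken from this commit:

import os
import torch

from sglang.srt.utils import should_use_tensor_core

# Force the tensor-core decode kernels on, regardless of dtype or head counts.
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "true"
assert should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=32)

# Any value other than "true" (case-insensitive) turns them off.
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "false"
assert not should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=4)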
@@ -1108,3 +1108,51 @@ def cuda_device_count_stateless() -> int:
     # This can be removed and simply replaced with torch.cuda.get_device_count
     # after https://github.com/pytorch/pytorch/pull/122815 is released.
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
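When the environment variable is unset and the installed flashinfer no longer exposes _grouped_size_compiled_for_decode_kernels (versions after 0.1.6), the dtype/GQA heuristic above decides. A minimal sketch of how it resolves, assuming that fallback path is taken; the head counts are illustrative only:

import torch

from sglang.srt.utils import should_use_tensor_core

# FP8 KV cache: always prefer the tensor-core decode kernels.
print(should_use_tensor_core(torch.float8_e4m3fn, 32, 8))  # True

# FP16 KV cache, GQA group size 32 // 8 = 4: not greater than 4, so no tensor cores.
print(should_use_tensor_core(torch.float16, 32, 8))  # False

# FP16 KV cache, GQA group size 32 // 4 = 8: greater than 4, so use tensor cores.
print(should_use_tensor_core(torch.float16, 32, 4))  # True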
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )
+from sglang.srt.utils import should_use_tensor_core
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(
 def init_flashinfer(num_attention_heads, num_kv_heads):
-    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
-        use_tensor_cores = True
-    else:
-        use_tensor_cores = False
+    use_tensor_cores = should_use_tensor_core(
+        torch.half, num_attention_heads, num_kv_heads
+    )
 
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
...
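The flag computed here is what ultimately selects flashinfer's decode kernel variant. A rough sketch of that wiring, assuming flashinfer's BatchDecodeWithPagedKVCacheWrapper accepts a use_tensor_cores argument (as in recent flashinfer releases); init_decode_wrapper is a hypothetical helper, not part of this commit:

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.utils import should_use_tensor_core

def init_decode_wrapper(num_attention_heads: int, num_kv_heads: int):
    # Same decision the updated init_flashinfer above makes for an fp16 KV cache.
    use_tensor_cores = should_use_tensor_core(
        torch.half, num_attention_heads, num_kv_heads
    )
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    # use_tensor_cores switches decode between the CUDA-core and tensor-core kernels.
    return BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )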