Unverified Commit 118b6af3 authored by Yineng Zhang, committed by GitHub

feat: add should_use_tensor_core (#2179)

parent 9449a954
@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
         # Parse constants
-        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = get_bool_env_var(
-                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-            )
-        else:
-            if not _grouped_size_compiled_for_decode_kernels(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-            ):
-                self.decode_use_tensor_cores = True
-            else:
-                self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
         self.max_context_len = model_runner.model_config.context_len
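A minimal sketch of how a flag like `self.decode_use_tensor_cores` typically reaches flashinfer's decode path, assuming flashinfer's `BatchDecodeWithPagedKVCacheWrapper` accepts a `use_tensor_cores` keyword (as recent flashinfer releases do); the head counts and tensor-parallel size below are illustrative values, not taken from the diff:

```python
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.utils import should_use_tensor_core

# Illustrative values; in the backend these come from model_runner.
num_attention_heads, num_kv_heads, tp_size = 32, 8, 1

use_tensor_cores = should_use_tensor_core(
    kv_cache_dtype=torch.float16,
    num_attention_heads=num_attention_heads // tp_size,
    num_kv_heads=num_kv_heads,
)

# flashinfer kernels need a caller-provided workspace buffer.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    "NHD",  # KV cache layout
    use_tensor_cores=use_tensor_cores,
)
```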
@@ -1108,3 +1108,51 @@ def cuda_device_count_stateless() -> int:
     # This can be removed and simply replaced with torch.cuda.get_device_count
     # after https://github.com/pytorch/pytorch/pull/122815 is released.
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
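To make the helper's decision table concrete, here is a small usage sketch. It assumes a flashinfer build that no longer exposes `_grouped_size_compiled_for_decode_kernels` (newer than 0.1.6), so the dtype/GQA branch at the bottom of the function is the one that runs:

```python
import os

import torch

from sglang.srt.utils import should_use_tensor_core

# FP8 KV cache: always prefer tensor cores.
assert should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8)

# FP16 KV cache with GQA group size 32 // 8 = 4: not large enough (> 4 required).
assert not should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=8)

# FP16 KV cache with GQA group size 32 // 4 = 8: use tensor cores.
assert should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=4)

# The environment variable overrides every other rule.
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "false"
assert not should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8)
```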
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )
+from sglang.srt.utils import should_use_tensor_core
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(
 
 
 def init_flashinfer(num_attention_heads, num_kv_heads):
-    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
-        use_tensor_cores = True
-    else:
-        use_tensor_cores = False
+    use_tensor_cores = should_use_tensor_core(
+        torch.half, num_attention_heads, num_kv_heads
+    )
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")