Unverified commit cbdfb771, authored by Clay and committed by GitHub

Enable FlashInfer support for encoder models and add head_dim padding workaround (#6230)

parent 282eb59f
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -486,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
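The NOTE above refers to FlashInfer's restriction on small head dimensions (bge-small-en uses head_dim = 32). The commit title also mentions a head_dim padding workaround, which is not part of the hunk shown here; as a rough illustration of why such a workaround is safe, zero-padding q, k, and v along head_dim does not change the attention output, so the padded columns can simply be sliced off afterwards. A minimal sketch of that identity (pad_head_dim, attn, and the toy shapes are illustrative, not code from this commit):

import torch
import torch.nn.functional as F

def pad_head_dim(x: torch.Tensor, target: int) -> torch.Tensor:
    # Zero-pad the last (head_dim) dimension up to `target`.
    return F.pad(x, (0, target - x.shape[-1]))

def attn(q, k, v, sm_scale):
    # Plain non-causal attention, as used for encoder-only models.
    scores = torch.einsum("qhd,khd->hqk", q, k) * sm_scale
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

# Toy shapes: 4 tokens, 1 head, head_dim 32 padded up to 64.
q, k, v = (torch.randn(4, 1, 32) for _ in range(3))
sm_scale = 32 ** -0.5  # keep the softmax scale of the original head_dim

ref = attn(q, k, v, sm_scale)
out = attn(pad_head_dim(q, 64), pad_head_dim(k, 64), pad_head_dim(v, 64), sm_scale)
# Zero-padding q/k leaves q·k unchanged, and zero-padding v only adds
# zero output columns, so slicing recovers the original result.
assert torch.allclose(ref, out[..., :32], atol=1e-6)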
@@ -27,9 +27,9 @@ from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
 MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
-ATTENTION_BACKEND = ["torch_native", "triton"]
+ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
 BATCH_SIZE = [1, 2]
-TORCH_DTYPES = [torch.float32]
+TORCH_DTYPES = [torch.float32, torch.float16]
 sgl_to_st_ratio = []
@@ -126,6 +126,19 @@ class TestEncoderEmbeddingModels(CustomTestCase):
         for attention_backend in ATTENTION_BACKEND:
             for batch_size in BATCH_SIZE:
                 for torch_dtype in TORCH_DTYPES:
+                    # NOTE: FlashInfer currently has limitations with head_dim = 32 or
+                    # other dimensions.
+                    # The FlashInfer head_dim limitation itself is tracked here:
+                    # https://github.com/flashinfer-ai/flashinfer/issues/1048
+                    #
+                    # Flashinfer does not support torch.float32 for dtype_q, so skip it
+                    if attention_backend == "flashinfer":
+                        if (
+                            model == "BAAI/bge-small-en"
+                            or torch_dtype == torch.float32
+                        ):
+                            continue
                     self.assert_close_prefill_logits(
                         DEFAULT_PROMPTS,
                         model,
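With the new skip condition, FlashInfer is effectively exercised only for BAAI/bge-m3 in float16: bge-small-en is excluded because of the head_dim = 32 limitation referenced above, and float32 queries are skipped because FlashInfer does not support them for dtype_q. A standalone sketch that prints the resulting test matrix (is_skipped merely mirrors the condition added above and is not a helper in the test file; batch size is omitted since it does not affect the skip):

import itertools
import torch

MODELS = ["BAAI/bge-small-en", "BAAI/bge-m3"]
ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
TORCH_DTYPES = [torch.float32, torch.float16]

def is_skipped(model, backend, dtype):
    # Mirrors the skip condition added to the test loop above.
    return backend == "flashinfer" and (
        model == "BAAI/bge-small-en" or dtype == torch.float32
    )

for model, backend, dtype in itertools.product(MODELS, ATTENTION_BACKEND, TORCH_DTYPES):
    status = "skip" if is_skipped(model, backend, dtype) else "run"
    print(f"{status:4} {model:18} {backend:12} {dtype}")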