Unverified commit a1f011d0, authored by Mick and committed by GitHub

minor: determine mm attn backend based on platforms (#9303)

parent 9ec314c6
@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import is_cuda, print_info_once
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
     parallel_state,
-    split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )
 
-        # priority: server_args > passed qkv_backend > sdpa
-        if global_server_args_dict["mm_attention_backend"] is None:
-            if qkv_backend is None:
-                if is_cuda():
-                    # Double prefill throughput by setting attn backend to Triton on CUDA
-                    qkv_backend = "triton_attn"
-                else:
-                    qkv_backend = "sdpa"
+        # Select attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
 
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
             prefix=add_prefix("proj", prefix),
         )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server args override > constructor arg > platform default.
+
+        Platform defaults:
+        - CUDA SM90 (Hopper): "fa3"
+        - Other CUDA: "triton_attn"
+        - Non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, _ = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+
+        # "fa3" can still arrive via an explicit override, so guard it here.
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
         q = q.flatten(1, 2)
...
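For reference, below is a minimal, standalone sketch of the resolution logic this commit introduces: the priority chain (server-args override > constructor argument > platform default) and the Blackwell guard. The helper names `resolve_mm_attention_backend` and the stub bodies for `get_device_capability`, `is_cuda`, and `is_blackwell` are illustrative stand-ins for the real helpers in sglang.srt.utils, not part of the repository.

# Hypothetical, self-contained sketch of the backend-resolution logic above.
# Only the priority chain and the Blackwell guard mirror the committed code;
# the stubs below replace the sglang.srt.utils helpers.
from typing import Optional, Tuple


def get_device_capability() -> Tuple[int, int]:
    # Stub: pretend we are on Hopper (SM90). The real helper queries CUDA.
    return (9, 0)


def is_cuda() -> bool:
    return True  # stub


def is_blackwell() -> bool:
    # Blackwell GPUs report compute capability 10.x.
    return get_device_capability()[0] == 10


def resolve_mm_attention_backend(
    server_override: Optional[str], passed_backend: Optional[str]
) -> str:
    """Priority: server override > constructor arg > platform default."""
    if server_override is not None:
        backend = server_override
    elif passed_backend is not None:
        backend = passed_backend
    elif is_cuda():
        major, _ = get_device_capability()
        # SM90 (Hopper) defaults to FlashAttention 3; other CUDA parts
        # fall back to the Triton kernel.
        backend = "fa3" if major == 9 else "triton_attn"
    else:
        backend = "sdpa"

    # "fa3" can still come from an explicit override, so guard it here.
    if backend == "fa3" and is_blackwell():
        raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
    return backend


if __name__ == "__main__":
    print(resolve_mm_attention_backend(None, None))             # fa3 (SM90 stub)
    print(resolve_mm_attention_backend(None, "sdpa"))           # sdpa
    print(resolve_mm_attention_backend("triton_attn", "sdpa"))  # triton_attn

Note the ordering in the committed method: the guard runs after resolution rather than inside the platform branch, so an explicit "fa3" override on a Blackwell GPU fails fast instead of silently selecting an unsupported kernel.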