Unverified Commit a1f011d0 authored by Mick, committed by GitHub

minor: determine mm attn backend based on platforms (#9303)

parent 9ec314c6
@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import is_cuda, print_info_once
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
-    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )
 
-        # priority: server_args > passed qkv_backend > sdpa
-        if global_server_args_dict["mm_attention_backend"] is None:
-            if qkv_backend is None:
-                if is_cuda():
-                    # Double prefill throughput by setting attn backend to Triton on CUDA
-                    qkv_backend = "triton_attn"
-                else:
-                    qkv_backend = "sdpa"
+        # Select attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
 
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
             prefix=add_prefix("proj", prefix),
         )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server args override > constructor arg > platform default.
+
+        Platform defaults:
+            - CUDA with SM 9.x (Hopper): "fa3"
+            - other CUDA: "triton_attn"
+            - non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, minor = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
         q = q.flatten(1, 2)
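The new `_determine_attention_backend` resolves the backend in a fixed priority order: the `mm_attention_backend` server-args override wins, then the `qkv_backend` argument passed to the constructor, then a platform default (`fa3` on SM 9.x CUDA, `triton_attn` on other CUDA, `sdpa` elsewhere), with `fa3` rejected on Blackwell even when explicitly requested. Below is a minimal standalone sketch of that resolution order; the function name `resolve_backend` and its parameters are illustrative stand-ins for the sglang helpers, not part of the sglang API:

```python
from typing import Optional, Tuple


def resolve_backend(
    server_override: Optional[str],  # stand-in for global_server_args_dict["mm_attention_backend"]
    passed: Optional[str],           # stand-in for the constructor's qkv_backend argument
    cuda: bool,                      # stand-in for is_cuda()
    capability: Tuple[int, int],     # stand-in for get_device_capability()
    blackwell: bool,                 # stand-in for is_blackwell()
) -> str:
    # Priority 1: an explicit server-args override wins unconditionally.
    if server_override is not None:
        backend = server_override
    # Priority 2: the backend passed in by the model code.
    elif passed is not None:
        backend = passed
    # Priority 3: platform default, keyed on compute capability for CUDA.
    elif cuda:
        major, _minor = capability
        backend = "fa3" if major == 9 else "triton_attn"
    else:
        backend = "sdpa"
    # fa3 is rejected on Blackwell even when explicitly requested.
    if backend == "fa3" and blackwell:
        raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
    return backend


# Hopper (SM 9.0) defaults to fa3; Ampere (SM 8.0) to triton_attn; CPU to sdpa.
assert resolve_backend(None, None, True, (9, 0), False) == "fa3"
assert resolve_backend(None, None, True, (8, 0), False) == "triton_attn"
assert resolve_backend(None, None, False, (0, 0), False) == "sdpa"
# The server-args override beats both the passed backend and the platform default.
assert resolve_backend("sdpa", "triton_attn", True, (9, 0), False) == "sdpa"
```

One consequence of this ordering, visible in the diff above, is that the "backend not set" notice is only printed when neither the server args nor the constructor supplied a backend, i.e. exactly when the platform default was used.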