Unverified commit 83d87685, authored by Mick, committed by GitHub

vlm: adapt internvl to VisionAttention (#6870)

parent 2a5f0100
In sglang/srt/layers/attention/vision.py:

 from __future__ import annotations

+import dataclasses
+import functools
 import math
-from functools import lru_cache, wraps
-from typing import Optional, Tuple
+from functools import lru_cache
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, print_info_once

 _is_cuda = is_cuda()

@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import add_prefix, logger
+from sglang.srt.utils import add_prefix

 ROTARY_EMBED_CLASSES = {
     "normal": apply_rotary_pos_emb,
 }


-def execute_once(func):
-    has_run = None
-
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        nonlocal has_run
-        if not has_run:
-            func(*args, **kwargs)
-            has_run = True
-
-    return wrapper
-
-
-@execute_once
-def info_once(message: str):
-    logger.info(message)
+@dataclasses.dataclass
+class SingletonCache:
+    data: Any = None
+
+    def set_data(self, value: Any) -> None:
+        self.data = value
+
+    def get_data(self) -> Optional[Any]:
+        return self.data
+
+    def empty(self) -> bool:
+        return self.get_data() is None
+
+
+# TODO: requires real seqlens from images
+@functools.lru_cache(maxsize=128)
+def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+    """
+    Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+    Caches the result based on these parameters.
+    """
+    cu_seqlens = torch.arange(
+        0,
+        (batch_size + 1) * seqlen,
+        step=seqlen,
+        dtype=torch.int32,
+        device=device,
+    )
+    return cu_seqlens


 class VisionSdpaAttention(nn.Module):
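For intuition, a small standalone sketch (not part of the diff; the batch size and sequence length are made-up values) of what the cached helper produces. With equal-length vision sequences, the boundaries are simply multiples of seqlen, and lru_cache memoizes one tensor per (batch_size, seqlen, device) combination:

    import torch

    def cu_seqlens_for_shape(batch_size: int, seqlen: int) -> torch.Tensor:
        # Boundaries at 0, seqlen, 2*seqlen, ..., batch_size*seqlen
        return torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)

    print(cu_seqlens_for_shape(2, 3))  # tensor([0, 3, 6], dtype=torch.int32)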
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
-        cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        elif isinstance(cu_seqlens, SingletonCache):
+            if cu_seqlens.empty():
+                cu_seqlens.set_data(
+                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+                )
+            cu_seqlens = cu_seqlens.get_data()
+
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
         output = flash_attn_varlen_func(
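The forward now accepts cu_seqlens as None, a SingletonCache, or a plain tensor, and normalizes all three to an int32 tensor on the query's device. A short sketch (illustrative values only, not taken from the diff) of the two quantities derived from it right before the flash_attn_varlen_func call:

    import torch

    cu_seqlens = torch.tensor([0, 3, 6], dtype=torch.int32)   # two sequences of length 3
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]                # tensor([3, 3], dtype=torch.int32)
    max_seqlen = seq_lens.max().item()                         # 3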
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
-            info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]

-        info_once(f"Using {qkv_backend} as multimodal attention backend.")
+        print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
@@ -423,15 +448,16 @@
             # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)

-            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+            # [s, b, head, head_dim_sum]
             new_x_shape = qkv.size()[:-1] + (
                 head,
-                3 * self.hidden_size_per_attention_head,
+                self.q_size + 2 * self.kv_size,
             )
             qkv = qkv.view(*new_x_shape)

             # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

             # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
             k=k,
             v=v,
             bsz=bsz,
+            seq_len=s,
             cu_seqlens=cu_seqlens,
             attention_mask=attention_mask,
         )
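The change above swaps dist_utils.split_tensor_along_last_dim for an explicit size-based split of the fused qkv projection. A self-contained sketch with made-up sizes (assuming q_size == kv_size, as in plain multi-head vision attention; both names come from VisionAttention and are not defined in this excerpt):

    import torch

    s, b, head, head_size = 4, 2, 8, 64
    q_size = kv_size = head_size                         # assumption: q and kv slices have equal width
    qkv = torch.randn(s, b, head, q_size + 2 * kv_size)
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    print(q.shape, k.shape, v.shape)                      # each: torch.Size([4, 2, 8, 64])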
In sglang/srt/models/internvl.py:

@@ -11,21 +11,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union

 import torch

 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
 # Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
 import torch.nn.functional as F
-from einops import rearrange, repeat
-from sgl_kernel.flash_attn import flash_attn_varlen_func
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

+from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.utils import logger


-class FlashAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-            (default: 1/sqrt(d_keys) where d_keys is computed at
-            runtime)
-        attention_dropout: The dropout rate to apply to the attention
-            (default: 0.0)
-    """
-
-    def __init__(
-        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
-    ):
-        super().__init__()
-        self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
-
-    def forward(
-        self,
-        qkv,
-        causal=False,
-        max_s=None,
-    ):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
-                if unpadded: (nnz, 3, h, d)
-        """
-        assert qkv.dtype in [torch.float16, torch.bfloat16]
-        assert qkv.is_cuda
-
-        batch_size, seqlen, _, nheads, d = qkv.shape
-        if batch_size == 0 or seqlen == 0:
-            output_shape = (batch_size, seqlen, nheads, d)
-            return (
-                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
-                None,
-            )
-
-        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
-        q, k, v = qkv_reshaped.unbind(1)
-        max_s = seqlen
-        cu_seqlens = torch.arange(
-            0,
-            (batch_size + 1) * seqlen,
-            step=seqlen,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output_reshaped = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            softmax_scale=self.softmax_scale,
-            causal=causal,
-        )
-        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
-        return output, None
-
-
 class InternAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        quant_config: QuantizationConfig = None,
+    ):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
         self.scale = self.head_dim**-0.5

-        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+        self.attn = VisionAttention(
+            qkv_backend="fa3",
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,
+            projection_size=self.embed_dim,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=getattr(config, "dropout", 0.0),
+            proj_bias=getattr(config, "qkv_bias", True),
+            flatten_batch=False,
+        )

         self.proj_drop = nn.Dropout(config.dropout)

         self.qk_normalization = config.qk_normalization
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
             self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

-        self.inner_attn = FlashAttention(softmax_scale=self.scale)
-
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-    def _flash_attn(
-        self,
-        x,
-    ):
-        qkv = self.qkv(x)
-        qkv = rearrange(
-            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
-        )
-
-        if self.qk_normalization:
-            q, k, v = qkv.unbind(2)
-            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
-            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
-            qkv = torch.stack([q, k, v], dim=2)
-
-        context, _ = self.inner_attn(
-            qkv,
-        )
-        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
-        outs = self.proj_drop(outs)
-        return outs
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._flash_attn(hidden_states)
-        return x
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> torch.Tensor:
+        out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
+        outs = self.proj_drop(out)
+        return outs


 class InternVisionEmbeddings(nn.Module):
     def __init__(self, config: PretrainedConfig):
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
     ) -> Tuple[
         torch.FloatTensor,
         Optional[torch.FloatTensor],
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
         Args:
            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
         hidden_states = hidden_states + self.drop_path1(
-            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+            self.attn(
+                self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
+            )
+            * self.ls1
         )

         hidden_states = hidden_states + self.drop_path2(
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
         encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds

+        cu_seqlens = SingletonCache()
+
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-            )
+            layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
             hidden_states = layer_outputs

         if output_hidden_states:
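One detail worth spelling out: every encoder layer receives the same SingletonCache object, so cu_seqlens is computed by the first attention call and reused by all later layers within one forward pass. A minimal sketch of that sharing pattern (standalone, with made-up values; not taken from the diff):

    import torch
    from sglang.srt.layers.attention.vision import SingletonCache

    cache = SingletonCache()
    for layer_idx in range(3):
        if cache.empty():                                        # True only on the first layer
            cache.set_data(torch.tensor([0, 3, 6], dtype=torch.int32))
        cu_seqlens = cache.get_data()                            # same tensor object every iteration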
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()

         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                if "vision_model" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.", r"attn.attn.")
+                    name = name.replace(r"qkv.", r"qkv_proj.")
+
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
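To make the renaming concrete, here is a small sketch of how an old-style vision-tower key would be rewritten before lookup in params_dict (the example key is hypothetical, chosen only to illustrate the two replacements):

    name = "vision_model.encoder.layers.0.attn.qkv.weight"
    if "vision_model" in name:
        name = name.replace("attn.", "attn.attn.")   # attention module is now nested one level deeper
        name = name.replace("qkv.", "qkv_proj.")     # fused projection renamed to match VisionAttention
    print(name)  # vision_model.encoder.layers.0.attn.attn.qkv_proj.weight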
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                f"Some weights are not initialized from checkpoints: {unloaded_params}"
+            )
+        return loaded_params


 EntryClass = InternVLChatModel
In sglang/srt/utils.py:

@@ -17,6 +17,7 @@ import base64
 import builtins
 import ctypes
 import dataclasses
+import functools
 import importlib
 import io
 import ipaddress
@@ -1386,6 +1387,11 @@ def print_warning_once(msg: str) -> None:
     logger.warning(msg, stacklevel=2)


+@functools.lru_cache(None)
+def print_info_once(msg: str) -> None:
+    logger.info(msg)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
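The new print_info_once mirrors the existing print_warning_once: functools.lru_cache memoizes on the message string, so repeating the exact same message is a no-op while a different message still logs once. A quick sketch of that behaviour (standalone, using Python's standard logging module rather than sglang's logger):

    import functools
    import logging

    logger = logging.getLogger(__name__)

    @functools.lru_cache(None)
    def print_info_once(msg: str) -> None:
        logger.info(msg)

    print_info_once("Using fa3 as multimodal attention backend.")  # logs
    print_info_once("Using fa3 as multimodal attention backend.")  # cached call, nothing logged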