Unverified commit 7101e085, authored by Isotr0py and committed by GitHub
Browse files

[Models]: Use `MMEncoderAttention` for MoonViT (#31738)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: h100 <h100@inferact.ai>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: h100 <h100@inferact.ai>
parent e9717801
...@@ -325,7 +325,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -325,7 +325,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.hidden_size = config.text_config.hidden_size self.hidden_size = config.text_config.hidden_size
self.vision_tower = MoonVitPretrainedModel( self.vision_tower = MoonVitPretrainedModel(
config.vision_config, config.vision_config,
self.use_data_parallel, multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -51,118 +51,20 @@ import torch.nn as nn ...@@ -51,118 +51,20 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
elif current_platform.is_xpu():
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
def multihead_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: torch.Tensor | None = None,
k_cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
"""Multi-head attention using flash attention 2.
Args:
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
The first element should be 0 and the last element should be q.shape[0].
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
The first element should be 0 and the last element should be k.shape[0].
Returns:
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
where dim = num_heads * head_dim
"""
# Unified format legal check
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], (
"k_cu_seqlens must sum to k.shape[0]"
)
assert q.dtype in [
torch.bfloat16,
torch.float16,
], f"unsupported dtype {q.dtype} for multihead attn"
max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
attn_out = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=q_cu_seqlens,
cu_seqlens_k=k_cu_seqlens,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
causal=False,
)
attn_out = attn_out.flatten(start_dim=-2)
return attn_out
def sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: torch.Tensor | None = None,
k_cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
"""SDPA attention.
Args:
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
q_cu_seqlens: Optional cumulative sequence lengths of q.
k_cu_seqlens: Optional cumulative sequence lengths of k.
"""
seq_length = q.shape[0]
attention_mask = torch.zeros(
[1, seq_length, seq_length], device=q.device, dtype=torch.bool
)
for i in range(1, len(q_cu_seqlens)):
attention_mask[
...,
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
return attn_output
# Registry mapping an attention-implementation name to the packed attention
# function that implements it; looked up by name when attention is dispatched.
VL_VISION_ATTENTION_FUNCTIONS = {
"flash_attention_2": multihead_attention,
"sdpa": sdpa_attention,
}
def _apply_rope_input_validation(x, freqs_cis): def _apply_rope_input_validation(x, freqs_cis):
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape) assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
...@@ -411,11 +313,19 @@ class MLP2(nn.Module): ...@@ -411,11 +313,19 @@ class MLP2(nn.Module):
super().__init__() super().__init__()
assert len(dims) == 3 assert len(dims) == 3
self.use_data_parallel = use_data_parallel self.use_data_parallel = use_data_parallel
self.fc0 = ReplicatedLinear( self.fc0 = ColumnParallelLinear(
dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0") dims[0],
dims[1],
bias=bias,
prefix=maybe_prefix(prefix, "fc0"),
disable_tp=self.use_data_parallel,
) )
self.fc1 = ReplicatedLinear( self.fc1 = RowParallelLinear(
dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1") dims[1],
dims[2],
bias=bias,
prefix=maybe_prefix(prefix, "fc1"),
disable_tp=self.use_data_parallel,
) )
self.activation = activation self.activation = activation
...@@ -433,35 +343,55 @@ class MoonVitEncoderLayer(nn.Module): ...@@ -433,35 +343,55 @@ class MoonVitEncoderLayer(nn.Module):
hidden_dim: int, hidden_dim: int,
mlp_dim: int, mlp_dim: int,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, multimodal_config: MultiModalConfig | None = None,
*, *,
attn_implementation: str = "sdpa",
activation=F.gelu, activation=F.gelu,
attn_bias: bool = False, attn_bias: bool = False,
): ):
super().__init__() super().__init__()
self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.num_heads = num_heads self.num_heads = num_heads
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
self.attn_implementation = attn_implementation self.tp_size = (
# use fa2 in vllm by default 1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
if is_flash_attn_2_available() or current_platform.is_xpu(): )
self.attn_implementation = "flash_attention_2" self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)
self.norm0 = nn.LayerNorm(hidden_dim) self.norm0 = nn.LayerNorm(hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim) self.norm1 = nn.LayerNorm(hidden_dim)
self.use_data_parallel = use_data_parallel
self.mlp = MLP2( self.mlp = MLP2(
[hidden_dim, mlp_dim, hidden_dim], [hidden_dim, mlp_dim, hidden_dim],
activation, activation,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel, use_data_parallel=self.use_data_parallel,
) )
self.wqkv = ReplicatedLinear( self.wqkv = QKVParallelLinear(
hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv" hidden_size=hidden_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=attn_bias,
prefix=f"{prefix}.wqkv",
disable_tp=self.use_data_parallel,
) )
self.wo = ReplicatedLinear( self.wo = RowParallelLinear(
hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo" hidden_dim,
hidden_dim,
bias=attn_bias,
prefix=f"{prefix}.wo",
disable_tp=self.use_data_parallel,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
) )
def attention_qkvpacked( def attention_qkvpacked(
...@@ -472,14 +402,15 @@ class MoonVitEncoderLayer(nn.Module): ...@@ -472,14 +402,15 @@ class MoonVitEncoderLayer(nn.Module):
): ):
""" """
Args: Args:
x (torch.Tensor): (batch_size, seqlen, hidden_dim) x (torch.Tensor): (seqlen, hidden_dim)
cu_seqlens (torch.Tensor): cu_seqlens (torch.Tensor):
""" """
seq_length = x.size(0)
xqkv, _ = self.wqkv(x) xqkv, _ = self.wqkv(x)
qkv_shape = xqkv.size()[:-1] + ( qkv_shape = xqkv.size()[:-1] + (
3, 3,
self.num_heads, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
) )
# xqkv: (batch_size, seqlen, 3, nheads, headdim) # xqkv: (batch_size, seqlen, 3, nheads, headdim)
...@@ -488,9 +419,18 @@ class MoonVitEncoderLayer(nn.Module): ...@@ -488,9 +419,18 @@ class MoonVitEncoderLayer(nn.Module):
xq, xk = apply_rope(xq, xk, rope_freqs_cis) xq, xk = apply_rope(xq, xk, rope_freqs_cis)
attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_out = attn_func( attn_out = self.attn(
xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens xq.unsqueeze(0),
xk.unsqueeze(0),
xv.unsqueeze(0),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_out = attn_out.reshape(
seq_length,
self.num_attention_heads_per_partition
* self.hidden_size_per_attention_head,
) )
attn_out, _ = self.wo(attn_out) attn_out, _ = self.wo(attn_out)
return attn_out return attn_out
...@@ -528,7 +468,7 @@ class MoonVitEncoder(nn.Module): ...@@ -528,7 +468,7 @@ class MoonVitEncoder(nn.Module):
num_layers: int, num_layers: int,
block_cfg: dict, block_cfg: dict,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, multimodal_config: MultiModalConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -538,7 +478,7 @@ class MoonVitEncoder(nn.Module): ...@@ -538,7 +478,7 @@ class MoonVitEncoder(nn.Module):
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
MoonVitEncoderLayer( MoonVitEncoderLayer(
use_data_parallel=use_data_parallel, multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
**block_cfg, **block_cfg,
) )
...@@ -599,31 +539,6 @@ def patch_merger( ...@@ -599,31 +539,6 @@ def patch_merger(
return outputs return outputs
class MoonVitVLProjector(nn.Module):
def __init__(
self,
in_channels: int,
merge_kernel_size: list[int, int],
hidden_act: str = "gelu",
ln_eps: float = 1e-5,
out_dim: int = 4096,
):
super().__init__()
self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps)
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
self.act = ACT2FN[hidden_act]
self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class MoonVitPretrainedModel(PreTrainedModel): class MoonVitPretrainedModel(PreTrainedModel):
config_class = MoonViTConfig config_class = MoonViTConfig
model_type = "moonvit" model_type = "moonvit"
...@@ -634,14 +549,13 @@ class MoonVitPretrainedModel(PreTrainedModel): ...@@ -634,14 +549,13 @@ class MoonVitPretrainedModel(PreTrainedModel):
def __init__( def __init__(
self, self,
config: MoonViTConfig, config: MoonViTConfig,
use_data_parallel: bool = False, multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
*inputs, *inputs,
**kwargs, **kwargs,
): ):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
config = deepcopy(config) config = deepcopy(config)
self.use_data_parallel = use_data_parallel
self.merge_kernel_size = config.merge_kernel_size self.merge_kernel_size = config.merge_kernel_size
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.patch_size = config.patch_size self.patch_size = config.patch_size
...@@ -662,9 +576,9 @@ class MoonVitPretrainedModel(PreTrainedModel): ...@@ -662,9 +576,9 @@ class MoonVitPretrainedModel(PreTrainedModel):
"mlp_dim": config.intermediate_size, "mlp_dim": config.intermediate_size,
"activation": ACT2FN["gelu_pytorch_tanh"], "activation": ACT2FN["gelu_pytorch_tanh"],
"attn_bias": True, "attn_bias": True,
"attn_implementation": config._attn_implementation,
}, },
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
multimodal_config=multimodal_config,
) )
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment