Unverified Commit ec266536 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)

parent 0fbc6696
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from xformers import ops as xops from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
...@@ -21,6 +21,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -21,6 +21,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
...@@ -156,7 +162,7 @@ class BlipVisionEmbeddings(nn.Module): ...@@ -156,7 +162,7 @@ class BlipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class BlipAttention(nn.Module): class BlipParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
...@@ -224,7 +230,7 @@ class BlipAttention(nn.Module): ...@@ -224,7 +230,7 @@ class BlipAttention(nn.Module):
out = out.view(bsz, tgt_len, -1) out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out) attn_output, _ = self.projection(out)
return attn_output return attn_output, None
class BlipMLP(nn.Module): class BlipMLP(nn.Module):
...@@ -261,7 +267,16 @@ class BlipEncoderLayer(nn.Module): ...@@ -261,7 +267,16 @@ class BlipEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = BlipAttention(config, quant_config=quant_config) # fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config) self.mlp = BlipMLP(config, quant_config=quant_config)
...@@ -272,7 +287,7 @@ class BlipEncoderLayer(nn.Module): ...@@ -272,7 +287,7 @@ class BlipEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from xformers import ops as xops from transformers.models.clip.modeling_clip import CLIPSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
...@@ -22,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -22,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
...@@ -162,7 +168,7 @@ class CLIPVisionEmbeddings(nn.Module): ...@@ -162,7 +168,7 @@ class CLIPVisionEmbeddings(nn.Module):
return embeddings return embeddings
class CLIPAttention(nn.Module): class CLIPParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
...@@ -231,7 +237,7 @@ class CLIPAttention(nn.Module): ...@@ -231,7 +237,7 @@ class CLIPAttention(nn.Module):
out = out.view(bsz, tgt_len, -1) out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output return attn_output, None
class CLIPMLP(nn.Module): class CLIPMLP(nn.Module):
...@@ -266,7 +272,13 @@ class CLIPEncoderLayer(nn.Module): ...@@ -266,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = CLIPAttention(config, quant_config=quant_config) num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config) self.mlp = CLIPMLP(config, quant_config=quant_config)
...@@ -278,7 +290,7 @@ class CLIPEncoderLayer(nn.Module): ...@@ -278,7 +290,7 @@ class CLIPEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
...@@ -365,6 +377,10 @@ class CLIPVisionModel(nn.Module): ...@@ -365,6 +377,10 @@ class CLIPVisionModel(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None): num_hidden_layers_override: Optional[int] = None):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
...@@ -386,7 +402,7 @@ class CLIPVisionModel(nn.Module): ...@@ -386,7 +402,7 @@ class CLIPVisionModel(nn.Module):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
......
...@@ -10,7 +10,6 @@ import torch ...@@ -10,7 +10,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from xformers import ops as xops
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -21,6 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -21,6 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
NORM2FN = { NORM2FN = {
'rms_norm': RMSNorm, 'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm, 'layer_norm': nn.LayerNorm,
...@@ -81,7 +86,7 @@ class InternVisionEmbeddings(nn.Module): ...@@ -81,7 +86,7 @@ class InternVisionEmbeddings(nn.Module):
return embeddings return embeddings
class InternAttention(nn.Module): class InternParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
...@@ -140,18 +145,67 @@ class InternAttention(nn.Module): ...@@ -140,18 +145,67 @@ class InternAttention(nn.Module):
k = self.k_norm.forward_native(k.flatten(-2, k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_) -1)).view(B_, N_, H_, D_)
x = xops.memory_efficient_attention_forward( x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
q,
k,
v,
scale=self.scale,
)
x = x.view(B, N, -1) x = x.view(B, N, -1)
x, _ = self.proj(x) x, _ = self.proj(x)
return x return x
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).view(B, N, -1)
x = self.proj(x)
return x
class InternMLP(nn.Module): class InternMLP(nn.Module):
def __init__(self, def __init__(self,
...@@ -187,7 +241,14 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -187,7 +241,14 @@ class InternVisionEncoderLayer(nn.Module):
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type self.norm_type = config.norm_type
self.attn = InternAttention(config, quant_config=quant_config) # fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.mlp = InternMLP(config, quant_config=quant_config) self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
......
...@@ -307,26 +307,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -307,26 +307,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if key_to_modify in name: if key_to_modify in name:
name = name.replace(key_to_modify, new_key) name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False use_default_weight_loading = False
for (param_name, shard_name, shard_id) in stacked_params_mapping: if "vision" not in name or self.vision_tower.shard_weight:
if shard_name not in name: for (param_name, shard_name,
continue shard_id) in stacked_params_mapping:
name = name.replace(shard_name, param_name) if shard_name not in name:
# Skip loading extra bias for GPTQ models. continue
if name.endswith(".bias") and name not in params_dict: name = name.replace(shard_name, param_name)
continue # Skip loading extra bias for GPTQ models.
param = params_dict[name] if name.endswith(".bias") and name not in params_dict:
weight_loader = param.weight_loader continue
weight_loader(param, loaded_weight, shard_id) param = params_dict[name]
break weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
else: else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True use_default_weight_loading = True
if use_default_weight_loading: if use_default_weight_loading:
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import SiglipVisionConfig from transformers import SiglipVisionConfig
from xformers import ops as xops from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
...@@ -26,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -26,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible # Since interpolation is applied, the image size need not be divisible
...@@ -219,7 +225,7 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -219,7 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class SiglipAttention(nn.Module): class SiglipParallelAttention(nn.Module):
def __init__( def __init__(
self, self,
...@@ -282,7 +288,7 @@ class SiglipAttention(nn.Module): ...@@ -282,7 +288,7 @@ class SiglipAttention(nn.Module):
out = out.view(batch_size, q_len, -1) out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output return attn_output, None
class SiglipMLP(nn.Module): class SiglipMLP(nn.Module):
...@@ -327,7 +333,14 @@ class SiglipEncoderLayer(nn.Module): ...@@ -327,7 +333,14 @@ class SiglipEncoderLayer(nn.Module):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config, quant_config=quant_config) num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = SiglipSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
...@@ -344,7 +357,7 @@ class SiglipEncoderLayer(nn.Module): ...@@ -344,7 +357,7 @@ class SiglipEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
...@@ -476,6 +489,10 @@ class SiglipVisionModel(nn.Module): ...@@ -476,6 +489,10 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
): ):
super().__init__() super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config, quant_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment