Unverified Commit e7ebb662 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Model] Remove transformers attention porting in VITs (#10414)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent 5be4e52b
...@@ -4,10 +4,11 @@ from typing import Iterable, Optional, Set, Tuple, Union ...@@ -4,10 +4,11 @@ from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention
from vllm.attention.selector import _Backend
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.inputs import DecoderOnlyInputs, token_inputs
...@@ -21,11 +22,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -21,11 +22,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
try: from .utils import get_vit_attn_backend
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
...@@ -168,7 +165,7 @@ class BlipVisionEmbeddings(nn.Module): ...@@ -168,7 +165,7 @@ class BlipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class BlipParallelAttention(nn.Module): class BlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
...@@ -208,6 +205,12 @@ class BlipParallelAttention(nn.Module): ...@@ -208,6 +205,12 @@ class BlipParallelAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"BLIP does not support {self.attn_backend} backend now.")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous() self.head_dim).transpose(1, 2).contiguous()
...@@ -231,11 +234,26 @@ class BlipParallelAttention(nn.Module): ...@@ -231,11 +234,26 @@ class BlipParallelAttention(nn.Module):
self.num_heads_per_partition, self.num_heads_per_partition,
self.head_dim) self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states, out = xops.memory_efficient_attention_forward(query_states,
key_states, key_states,
value_states, value_states,
p=self.dropout, p=self.dropout,
scale=self.scale) scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(bsz, tgt_len, -1) out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out) attn_output, _ = self.projection(out)
...@@ -285,18 +303,11 @@ class BlipEncoderLayer(nn.Module): ...@@ -285,18 +303,11 @@ class BlipEncoderLayer(nn.Module):
super().__init__() super().__init__()
# fallback to sdpa attention if tp unavailable # fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads self.self_attn = BlipAttention(
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, self.mlp = BlipMLP(config,
...@@ -374,11 +385,6 @@ class BlipVisionModel(nn.Module): ...@@ -374,11 +385,6 @@ class BlipVisionModel(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.config = config self.config = config
self.embeddings = BlipVisionEmbeddings(config) self.embeddings = BlipVisionEmbeddings(config)
...@@ -422,7 +428,7 @@ class BlipVisionModel(nn.Module): ...@@ -422,7 +428,7 @@ class BlipVisionModel(nn.Module):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] if self.shard_weight else [] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
layer_count = len(self.encoder.layers) layer_count = len(self.encoder.layers)
......
...@@ -5,10 +5,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union ...@@ -5,10 +5,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
from vllm.attention.selector import _Backend
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.inputs import DecoderOnlyInputs, token_inputs
...@@ -23,11 +24,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -23,11 +24,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
try: from .utils import get_vit_attn_backend
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
...@@ -197,7 +194,7 @@ class CLIPVisionEmbeddings(nn.Module): ...@@ -197,7 +194,7 @@ class CLIPVisionEmbeddings(nn.Module):
return embeddings return embeddings
class CLIPParallelAttention(nn.Module): class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
...@@ -237,6 +234,12 @@ class CLIPParallelAttention(nn.Module): ...@@ -237,6 +234,12 @@ class CLIPParallelAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"CLIP does not support {self.attn_backend} backend now.")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous() self.head_dim).transpose(1, 2).contiguous()
...@@ -261,11 +264,26 @@ class CLIPParallelAttention(nn.Module): ...@@ -261,11 +264,26 @@ class CLIPParallelAttention(nn.Module):
self.num_heads_per_partition, self.num_heads_per_partition,
self.head_dim) self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states, out = xops.memory_efficient_attention_forward(query_states,
key_states, key_states,
value_states, value_states,
p=self.dropout, p=self.dropout,
scale=self.scale) scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(bsz, tgt_len, -1) out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
...@@ -311,17 +329,11 @@ class CLIPEncoderLayer(nn.Module): ...@@ -311,17 +329,11 @@ class CLIPEncoderLayer(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = CLIPAttention(
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
else:
self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, self.mlp = CLIPMLP(config,
...@@ -461,11 +473,6 @@ class CLIPVisionModel(nn.Module): ...@@ -461,11 +473,6 @@ class CLIPVisionModel(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
...@@ -490,7 +497,7 @@ class CLIPVisionModel(nn.Module): ...@@ -490,7 +497,7 @@ class CLIPVisionModel(nn.Module):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] if self.shard_weight else [] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
......
...@@ -12,6 +12,7 @@ import torch.nn as nn ...@@ -12,6 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.selector import _Backend
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
...@@ -24,11 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -24,11 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
try: from .utils import get_vit_attn_backend
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
NORM2FN = { NORM2FN = {
'rms_norm': RMSNorm, 'rms_norm': RMSNorm,
...@@ -186,6 +183,11 @@ class InternParallelAttention(nn.Module): ...@@ -186,6 +183,11 @@ class InternParallelAttention(nn.Module):
prefix=f"{prefix}.proj", prefix=f"{prefix}.proj",
) )
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"InternViT does not support {self.attn_backend} backend now.")
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1: if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous()) q = tensor_model_parallel_all_gather(q.contiguous())
...@@ -211,11 +213,21 @@ class InternParallelAttention(nn.Module): ...@@ -211,11 +213,21 @@ class InternParallelAttention(nn.Module):
k = k.view(B, N, self.num_heads_per_partition, self.head_dim) k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim) v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale) if self.attn_backend == _Backend.XFORMERS:
x = x.view(B, N, -1) from xformers import ops as xops
x, _ = self.proj(x) out = xops.memory_efficient_attention_forward(q,
return x k,
v,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(1, 2)
out = out.view(B, N, -1)
out, _ = self.proj(out)
return out
class InternSdpaAttention(nn.Module): class InternSdpaAttention(nn.Module):
...@@ -362,7 +374,7 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -362,7 +374,7 @@ class InternVisionEncoderLayer(nn.Module):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0: if (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config, return InternParallelAttention(config,
quant_config=quant_config, quant_config=quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
......
...@@ -187,7 +187,7 @@ class MultiHeadDotProductAttention(nn.Module): ...@@ -187,7 +187,7 @@ class MultiHeadDotProductAttention(nn.Module):
) )
# Detect attention implementation. # Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend() self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
}: }:
......
...@@ -260,7 +260,7 @@ class Qwen2VisionAttention(nn.Module): ...@@ -260,7 +260,7 @@ class Qwen2VisionAttention(nn.Module):
prefix=f"{prefix}.proj") prefix=f"{prefix}.proj")
# Detect attention implementation. # Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend() self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
}: }:
......
...@@ -6,11 +6,12 @@ from typing import Iterable, List, Optional, Set, Tuple, Union ...@@ -6,11 +6,12 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import SiglipVisionConfig from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
from vllm.attention.selector import _Backend
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.inputs import DecoderOnlyInputs, token_inputs
...@@ -27,11 +28,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -27,11 +28,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
try: from .utils import get_vit_attn_backend
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
...@@ -254,7 +251,7 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -254,7 +251,7 @@ class SiglipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class SiglipParallelAttention(nn.Module): class SiglipAttention(nn.Module):
def __init__( def __init__(
self, self,
...@@ -293,6 +290,11 @@ class SiglipParallelAttention(nn.Module): ...@@ -293,6 +290,11 @@ class SiglipParallelAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"SIGLIP does not support {self.attn_backend} backend now.")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -313,11 +315,26 @@ class SiglipParallelAttention(nn.Module): ...@@ -313,11 +315,26 @@ class SiglipParallelAttention(nn.Module):
self.num_heads_per_partition, self.num_heads_per_partition,
self.head_dim) self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states, out = xops.memory_efficient_attention_forward(query_states,
key_states, key_states,
value_states, value_states,
p=self.dropout, p=self.dropout,
scale=self.scale) scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(batch_size, q_len, -1) out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
...@@ -372,17 +389,11 @@ class SiglipEncoderLayer(nn.Module): ...@@ -372,17 +389,11 @@ class SiglipEncoderLayer(nn.Module):
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
num_heads = config.num_attention_heads self.self_attn = SiglipAttention(
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
else:
self.self_attn = SiglipSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
...@@ -569,10 +580,6 @@ class SiglipVisionModel(nn.Module): ...@@ -569,10 +580,6 @@ class SiglipVisionModel(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config, quant_config,
...@@ -601,7 +608,7 @@ class SiglipVisionModel(nn.Module): ...@@ -601,7 +608,7 @@ class SiglipVisionModel(nn.Module):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] if self.shard_weight else [] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
......
...@@ -587,7 +587,11 @@ class LLMWrapper(nn.Module): ...@@ -587,7 +587,11 @@ class LLMWrapper(nn.Module):
return llm(*args, **kwargs) return llm(*args, **kwargs)
def get_vit_attn_backend() -> _Backend: def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
"""
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend() selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None: if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
...@@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend: ...@@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend:
if selected_backend is None: if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead. # For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.has_device_capability(80) device_available = current_platform.has_device_capability(80)
if device_available: if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available(): if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
...@@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend: ...@@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend:
"so we use xformers backend instead. You can run " "so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.") "`pip install flash-attn` to use flash-attention backend.")
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
elif current_platform.is_cpu(): elif current_platform.is_cpu() or current_platform.is_rocm():
# ROCM doesn't support xformers
selected_backend = _Backend.TORCH_SDPA selected_backend = _Backend.TORCH_SDPA
else: else:
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment