Unverified commit 0fca3cdc, authored by Woosuk Kwon and committed by GitHub
Browse files

[Misc] Enhance attention selector (#4751)

parent e7c46b95
......@@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.kv_cache_dtype == "auto"
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
......
......@@ -5,9 +5,9 @@ from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"Attention",
"get_attn_backend",
"AttentionMetadataPerStage",
"get_attn_backend",
]
......@@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]):
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str
def __post_init__(self):
if self.num_prefill_tokens > 0:
......@@ -116,6 +114,7 @@ class AttentionImpl(ABC):
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
raise NotImplementedError
......@@ -127,6 +126,6 @@ class AttentionImpl(ABC):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
raise NotImplementedError
......@@ -140,16 +140,18 @@ class FlashAttentionImpl(AttentionImpl):
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
......@@ -167,7 +169,7 @@ class FlashAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -196,8 +198,7 @@ class FlashAttentionImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
self.kv_cache_dtype, kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
......@@ -264,7 +265,7 @@ class FlashAttentionImpl(AttentionImpl):
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
......
......@@ -149,20 +149,33 @@ class FlashInferImpl(AttentionImpl):
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.alibi_slopes = alibi_slopes
self.scale = scale
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.kv_cache_dtype = kv_cache_dtype
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float = 1.0,
) -> torch.Tensor:
assert kv_scale == 1.0
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
......@@ -183,7 +196,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
)
if prefill_meta := attn_metadata.prefill_metadata:
......
......@@ -138,25 +138,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
f"Supported head sizes are: {supported_head_sizes}.")
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
......@@ -229,7 +231,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
kv_scale,
)
......@@ -323,7 +325,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
......
......@@ -83,26 +83,32 @@ class TorchSDPABackendImpl(AttentionImpl):
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
assert len(alibi_slopes) == num_heads
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
def forward(
self,
......@@ -111,7 +117,7 @@ class TorchSDPABackendImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
......@@ -124,6 +130,7 @@ class TorchSDPABackendImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert kv_scale == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......@@ -136,8 +143,7 @@ class TorchSDPABackendImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
self.kv_cache_dtype, kv_scale)
if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
......@@ -195,7 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
......
......@@ -149,15 +149,17 @@ class XFormersImpl(AttentionImpl):
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
......@@ -175,7 +177,7 @@ class XFormersImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -188,7 +190,6 @@ class XFormersImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
......@@ -203,8 +204,7 @@ class XFormersImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
self.kv_cache_dtype, kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
......@@ -262,7 +262,7 @@ class XFormersImpl(AttentionImpl):
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
......
......@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
class Attention(nn.Module):
......@@ -29,10 +30,24 @@ class Attention(nn.Module):
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.backend = get_attn_backend(torch.get_default_dtype())
impl_cls = self.backend.get_impl_cls()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if num_kv_heads is None:
num_kv_heads = num_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window)
......
import enum
from functools import lru_cache
from typing import Type
from typing import Optional, Type
import torch
......@@ -21,8 +21,18 @@ class _Backend(enum.Enum):
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
backend = _which_attn_to_use(dtype)
def get_attn_backend(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
) -> Type[AttentionBackend]:
backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
if backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
......@@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
logger.warning("Eager mode is enforced for the Flashinfer backend.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
else:
raise ValueError("Invalid attention backend.")
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
def _which_attn_to_use(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
) -> _Backend:
"""Returns which flash attention backend to use."""
if is_cpu():
return _Backend.TORCH_SDPA
......
......@@ -2,26 +2,29 @@ from typing import Optional
from torch import nn
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture)
def get_model(
*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config)
scheduler_config=scheduler_config,
cache_config=cache_config)
__all__ = [
......
......@@ -9,9 +9,9 @@ import huggingface_hub
import torch
from torch import nn
from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
......@@ -77,15 +77,16 @@ def _get_model_initialization_kwargs(
return extra_kwargs
def _initialize_model(
model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))
......@@ -103,7 +104,8 @@ class BaseModelLoader(ABC):
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
"""Load a model with the given configurations."""
...
......@@ -216,11 +218,13 @@ class DefaultModelLoader(BaseModelLoader):
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
lora_config, vision_language_config,
cache_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
......@@ -253,11 +257,13 @@ class DummyModelLoader(BaseModelLoader):
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
lora_config, vision_language_config,
cache_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
......@@ -286,9 +292,12 @@ class TensorizerLoader(BaseModelLoader):
return tensorizer_weights_iterator(tensorizer_args)
def _load_model_unserialized(
self, model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig,
) -> nn.Module:
"""Load an unserialized model with tensorizer.
......@@ -299,15 +308,19 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
lora_config, vision_language_config,
cache_config)
model.load_weights(self._get_weights_iterator())
return model.eval()
def _load_model_serialized(
self, model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer.
......@@ -321,6 +334,7 @@ class TensorizerLoader(BaseModelLoader):
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)
extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
......@@ -335,16 +349,19 @@ class TensorizerLoader(BaseModelLoader):
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config)
if is_vllm_serialized_tensorizer(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config,
vision_language_config)
vision_language_config,
cache_config)
return self._load_model_unserialized(model_config, device_config,
lora_config,
vision_language_config)
vision_language_config,
cache_config)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
......
......@@ -5,6 +5,7 @@ import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -215,6 +216,7 @@ class ArcticAttention(nn.Module):
self,
config: ArcticConfig,
layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -265,7 +267,8 @@ class ArcticAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -288,6 +291,7 @@ class ArcticDecoderLayer(nn.Module):
self,
config: ArcticConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -297,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
self.use_residual = config.use_residual and is_moe_layer
self.self_attn = ArcticAttention(config,
layer_idx,
cache_config,
quant_config=quant_config)
self.block_sparse_moe = ArcticMoE(
config,
......@@ -356,6 +361,7 @@ class ArcticModel(nn.Module):
def __init__(
self,
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -366,7 +372,10 @@ class ArcticModel(nn.Module):
config.hidden_size,
org_num_embeddings=self.vocab_size)
self.layers = nn.ModuleList([
ArcticDecoderLayer(config, layer_idx, quant_config=quant_config)
ArcticDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self._attn_implementation = config._attn_implementation
......@@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module):
def __init__(self,
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
**kwargs) -> None:
super().__init__()
self.config = config
self.model = ArcticModel(config, quant_config)
self.model = ArcticModel(config, cache_config, quant_config)
self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.vocab_size,
......
......@@ -26,7 +26,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -111,6 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -162,7 +163,10 @@ class BaiChuanAttention(nn.Module):
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
def forward(
self,
......@@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -197,6 +202,7 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
self.mlp = BaiChuanMLP(
......@@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
......@@ -255,7 +262,8 @@ class BaiChuanModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, quant_config)
BaiChuanDecoderLayer(config, position_embedding, cache_config,
quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -304,13 +312,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
self,
config,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config)
self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", quant_config, lora_config)
super().__init__(config, "ROPE", cache_config, quant_config,
lora_config)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", quant_config, lora_config)
super().__init__(config, "ALIBI", cache_config, quant_config,
lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
......@@ -404,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, "ROPE", quant_config, lora_config)
super().__init__(config, "ROPE", cache_config, quant_config,
lora_config)
......@@ -24,6 +24,7 @@ from torch import nn
from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
......@@ -71,6 +72,7 @@ class BloomAttention(nn.Module):
def __init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -108,7 +110,8 @@ class BloomAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
cache_config=cache_config)
def forward(
self,
......@@ -158,6 +161,7 @@ class BloomBlock(nn.Module):
def __init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -165,7 +169,8 @@ class BloomBlock(nn.Module):
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, quant_config)
self.self_attention = BloomAttention(config, cache_config,
quant_config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config)
......@@ -214,6 +219,7 @@ class BloomModel(nn.Module):
def __init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -229,7 +235,7 @@ class BloomModel(nn.Module):
# Transformer blocks
self.h = nn.ModuleList([
BloomBlock(config, quant_config)
BloomBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module):
def __init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(config, quant_config)
self.transformer = BloomModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -9,7 +9,7 @@ from torch import nn
from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -34,6 +34,7 @@ class GLMAttention(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -90,6 +91,7 @@ class GLMAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
def forward(
......@@ -167,6 +169,7 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -181,7 +184,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, quant_config)
self.self_attention = GLMAttention(config, cache_config, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
......@@ -237,6 +240,7 @@ class GLMTransformer(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -246,8 +250,10 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers
# Transformer layers.
self.layers = nn.ModuleList(
[GLMBlock(config, quant_config) for i in range(self.num_layers)])
self.layers = nn.ModuleList([
GLMBlock(config, cache_config, quant_config)
for i in range(self.num_layers)
])
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
......@@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -292,7 +299,7 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, quant_config)
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
......@@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module):
def __init__(
self,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.transformer = ChatGLMModel(config, quant_config)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
......
......@@ -29,6 +29,7 @@ from torch.nn.parameter import Parameter
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -124,6 +125,7 @@ class CohereAttention(nn.Module):
def __init__(
self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -180,6 +182,7 @@ class CohereAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
......@@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, quant_config=quant_config)
self.self_attn = CohereAttention(config,
cache_config,
quant_config=quant_config)
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
......@@ -258,6 +264,7 @@ class CohereModel(nn.Module):
def __init__(
self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -266,7 +273,7 @@ class CohereModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
CohereDecoderLayer(config, quant_config=quant_config)
CohereDecoderLayer(config, cache_config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = LayerNorm(param_shape=(config.hidden_size),
......@@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module):
def __init__(
self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -306,7 +314,7 @@ class CohereForCausalLM(nn.Module):
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale)
self.model = CohereModel(config, quant_config)
self.model = CohereModel(config, cache_config, quant_config)
self.sampler = Sampler()
@torch.no_grad()
......
......@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -166,6 +167,7 @@ class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -221,6 +223,7 @@ class DbrxAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
def forward(
......@@ -279,10 +282,12 @@ class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config)
self.ffn = DbrxExperts(config, quant_config)
def forward(
......@@ -308,6 +313,7 @@ class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -315,8 +321,10 @@ class DbrxModel(nn.Module):
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.blocks = nn.ModuleList([
DbrxBlock(config, cache_config, quant_config)
for _ in range(config.n_layers)
])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
......@@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, quant_config)
self.transformer = DbrxModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
......
......@@ -28,7 +28,7 @@ from typing import Iterable, Optional, Tuple
import torch
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM):
# Constructor: rewrites the KV-head setting on `config`, then delegates
# construction to the parent class via super().__init__().
# NOTE(review): `config` defaults to None but is dereferenced unconditionally
# below, so calling this without a config raises AttributeError.
def __init__(
self,
config: Optional[PretrainedConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
# Collapse the per-layer KV head counts into one value: the largest entry.
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
# Remove the per-layer attribute from config now that the single value is set.
delattr(config, "num_key_value_heads_per_layer")
# Forward all four configs (including cache_config) to the parent constructor.
super().__init__(config=config,
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config)
......
......@@ -28,6 +28,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -178,6 +179,7 @@ class DeepseekAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -229,7 +231,8 @@ class DeepseekAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -252,6 +255,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -267,6 +271,7 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
if (config.n_routed_experts is not None
......@@ -321,6 +326,7 @@ class DeepseekModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -332,7 +338,10 @@ class DeepseekModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
DeepseekDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = DeepseekModel(config, quant_config)
self.model = DeepseekModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment