Unverified Commit 0fca3cdc authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Enhance attention selector (#4751)

parent e7c46b95
...@@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens) assert len(input_positions) == len(input_tokens)
assert attn_metadata.kv_cache_dtype == "auto"
assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager: if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size
......
...@@ -5,9 +5,9 @@ from vllm.attention.layer import Attention ...@@ -5,9 +5,9 @@ from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
__all__ = [ __all__ = [
"Attention",
"AttentionBackend", "AttentionBackend",
"AttentionMetadata", "AttentionMetadata",
"Attention",
"get_attn_backend",
"AttentionMetadataPerStage", "AttentionMetadataPerStage",
"get_attn_backend",
] ]
...@@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]): ...@@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]):
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively. # in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str
def __post_init__(self): def __post_init__(self):
if self.num_prefill_tokens > 0: if self.num_prefill_tokens > 0:
...@@ -116,6 +114,7 @@ class AttentionImpl(ABC): ...@@ -116,6 +114,7 @@ class AttentionImpl(ABC):
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -127,6 +126,6 @@ class AttentionImpl(ABC): ...@@ -127,6 +126,6 @@ class AttentionImpl(ABC):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
kv_scale: float, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -140,16 +140,18 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -140,16 +140,18 @@ class FlashAttentionImpl(AttentionImpl):
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
...@@ -167,7 +169,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -167,7 +169,7 @@ class FlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[FlashAttentionMetadata], attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -196,8 +198,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -196,8 +198,7 @@ class FlashAttentionImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype, kv_scale)
kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
...@@ -264,7 +265,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -264,7 +265,7 @@ class FlashAttentionImpl(AttentionImpl):
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, decode_meta.seq_lens_tensor,
decode_meta.max_seq_len, decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
......
...@@ -149,20 +149,33 @@ class FlashInferImpl(AttentionImpl): ...@@ -149,20 +149,33 @@ class FlashInferImpl(AttentionImpl):
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None: ) -> None:
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.alibi_slopes = alibi_slopes
self.scale = scale
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.kv_cache_dtype = kv_cache_dtype
def forward(self, query: torch.Tensor, key: torch.Tensor, assert self.num_heads % self.num_kv_heads == 0
value: torch.Tensor, kv_cache: Optional[torch.Tensor], self.num_queries_per_kv = self.num_heads // self.num_kv_heads
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float): def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float = 1.0,
) -> torch.Tensor:
assert kv_scale == 1.0
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
...@@ -183,7 +196,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -183,7 +196,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache[:, 0], kv_cache[:, 0],
kv_cache[:, 1], kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
attn_metadata.kv_cache_dtype, self.kv_cache_dtype,
) )
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
......
...@@ -138,25 +138,27 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -138,25 +138,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes() supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes: if head_size not in supported_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.") f"Supported head sizes are: {supported_head_sizes}.")
self.use_naive_attn = False self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
...@@ -229,7 +231,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -229,7 +231,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, kv_scale,
) )
...@@ -323,7 +325,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -323,7 +325,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, decode_meta.seq_lens_tensor,
decode_meta.max_seq_len, decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
......
...@@ -83,26 +83,32 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -83,26 +83,32 @@ class TorchSDPABackendImpl(AttentionImpl):
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None: if alibi_slopes is not None:
assert len(alibi_slopes) == num_heads
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None self.sliding_window = sliding_window
or self.sliding_window is not None) self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes() self.need_mask = (self.alibi_slopes is not None
if head_size not in suppored_head_sizes: or self.sliding_window is not None)
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.") f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
def forward( def forward(
self, self,
...@@ -111,7 +117,7 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -111,7 +117,7 @@ class TorchSDPABackendImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -124,6 +130,7 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -124,6 +130,7 @@ class TorchSDPABackendImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
...@@ -136,8 +143,7 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -136,8 +143,7 @@ class TorchSDPABackendImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype, kv_scale)
kv_scale)
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
...@@ -195,7 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -195,7 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
attn_metadata.block_tables, attn_metadata.block_tables,
attn_metadata.seq_lens_tensor, attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len, attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
......
...@@ -149,15 +149,17 @@ class XFormersImpl(AttentionImpl): ...@@ -149,15 +149,17 @@ class XFormersImpl(AttentionImpl):
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
...@@ -175,7 +177,7 @@ class XFormersImpl(AttentionImpl): ...@@ -175,7 +177,7 @@ class XFormersImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata], attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -188,7 +190,6 @@ class XFormersImpl(AttentionImpl): ...@@ -188,7 +190,6 @@ class XFormersImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
...@@ -203,8 +204,7 @@ class XFormersImpl(AttentionImpl): ...@@ -203,8 +204,7 @@ class XFormersImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype, kv_scale)
kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
...@@ -262,7 +262,7 @@ class XFormersImpl(AttentionImpl): ...@@ -262,7 +262,7 @@ class XFormersImpl(AttentionImpl):
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, decode_meta.seq_lens_tensor,
decode_meta.max_seq_len, decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.attention.backends.abstract import (AttentionMetadata, from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage) AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
class Attention(nn.Module): class Attention(nn.Module):
...@@ -29,10 +30,24 @@ class Attention(nn.Module): ...@@ -29,10 +30,24 @@ class Attention(nn.Module):
num_kv_heads: Optional[int] = None, num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.backend = get_attn_backend(torch.get_default_dtype()) if cache_config is not None:
impl_cls = self.backend.get_impl_cls() kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if num_kv_heads is None:
num_kv_heads = num_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window) alibi_slopes, sliding_window)
......
import enum import enum
from functools import lru_cache from functools import lru_cache
from typing import Type from typing import Optional, Type
import torch import torch
...@@ -21,8 +21,18 @@ class _Backend(enum.Enum): ...@@ -21,8 +21,18 @@ class _Backend(enum.Enum):
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: def get_attn_backend(
backend = _which_attn_to_use(dtype) num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
) -> Type[AttentionBackend]:
backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
if backend == _Backend.FLASH_ATTN: if backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention-2 backend.") logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401
...@@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: ...@@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
return TorchSDPABackend return TorchSDPABackend
elif backend == _Backend.FLASHINFER: elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.") logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend. ") logger.warning("Eager mode is enforced for the Flashinfer backend.")
from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend return FlashInferBackend
else: else:
raise ValueError("Invalid attention backend.") raise ValueError("Invalid attention backend.")
def _which_attn_to_use(dtype: torch.dtype) -> _Backend: def _which_attn_to_use(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
) -> _Backend:
"""Returns which flash attention backend to use.""" """Returns which flash attention backend to use."""
if is_cpu(): if is_cpu():
return _Backend.TORCH_SDPA return _Backend.TORCH_SDPA
......
...@@ -2,26 +2,29 @@ from typing import Optional ...@@ -2,26 +2,29 @@ from typing import Optional
from torch import nn from torch import nn
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader, from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader) get_model_loader)
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture) get_architecture_class_name, get_model_architecture)
def get_model( def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
*, model_config: ModelConfig, load_config: LoadConfig, device_config: DeviceConfig, parallel_config: ParallelConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig,
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config) loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config, return loader.load_model(model_config=model_config,
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
parallel_config=parallel_config, parallel_config=parallel_config,
scheduler_config=scheduler_config) scheduler_config=scheduler_config,
cache_config=cache_config)
__all__ = [ __all__ = [
......
...@@ -9,9 +9,9 @@ import huggingface_hub ...@@ -9,9 +9,9 @@ import huggingface_hub
import torch import torch
from torch import nn from torch import nn
from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig, ModelConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig, VisionLanguageConfig)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -77,15 +77,16 @@ def _get_model_initialization_kwargs( ...@@ -77,15 +77,16 @@ def _get_model_initialization_kwargs(
return extra_kwargs return extra_kwargs
def _initialize_model( def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
model_config: ModelConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig],
lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config) quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config, return model_class(config=model_config.hf_config,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
**_get_model_initialization_kwargs( **_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)) model_class, lora_config, vision_language_config))
...@@ -103,7 +104,8 @@ class BaseModelLoader(ABC): ...@@ -103,7 +104,8 @@ class BaseModelLoader(ABC):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
"""Load a model with the given configurations.""" """Load a model with the given configurations."""
... ...
...@@ -216,11 +218,13 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -216,11 +218,13 @@ class DefaultModelLoader(BaseModelLoader):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config) lora_config, vision_language_config,
cache_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(model_config.model, self._get_weights_iterator(model_config.model,
model_config.revision, model_config.revision,
...@@ -253,11 +257,13 @@ class DummyModelLoader(BaseModelLoader): ...@@ -253,11 +257,13 @@ class DummyModelLoader(BaseModelLoader):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config) lora_config, vision_language_config,
cache_config)
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model) initialize_dummy_weights(model)
...@@ -286,9 +292,12 @@ class TensorizerLoader(BaseModelLoader): ...@@ -286,9 +292,12 @@ class TensorizerLoader(BaseModelLoader):
return tensorizer_weights_iterator(tensorizer_args) return tensorizer_weights_iterator(tensorizer_args)
def _load_model_unserialized( def _load_model_unserialized(
self, model_config: ModelConfig, device_config: DeviceConfig, self,
lora_config: Optional[LoRAConfig], model_config: ModelConfig,
vision_language_config: Optional[VisionLanguageConfig] device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load an unserialized model with tensorizer. """Load an unserialized model with tensorizer.
...@@ -299,15 +308,19 @@ class TensorizerLoader(BaseModelLoader): ...@@ -299,15 +308,19 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config) lora_config, vision_language_config,
cache_config)
model.load_weights(self._get_weights_iterator()) model.load_weights(self._get_weights_iterator())
return model.eval() return model.eval()
def _load_model_serialized( def _load_model_serialized(
self, model_config: ModelConfig, device_config: DeviceConfig, self,
lora_config: Optional[LoRAConfig], model_config: ModelConfig,
vision_language_config: Optional[VisionLanguageConfig] device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer. """Load a serialized model with tensorizer.
...@@ -321,6 +334,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -321,6 +334,7 @@ class TensorizerLoader(BaseModelLoader):
extra_kwargs = _get_model_initialization_kwargs( extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config) model_class, lora_config, vision_language_config)
extra_kwargs["quant_config"] = quant_config extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class tensorizer_config.model_class = model_class
...@@ -335,16 +349,19 @@ class TensorizerLoader(BaseModelLoader): ...@@ -335,16 +349,19 @@ class TensorizerLoader(BaseModelLoader):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config) self._verify_config(model_config, parallel_config)
if is_vllm_serialized_tensorizer(self.tensorizer_config): if is_vllm_serialized_tensorizer(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config, return self._load_model_serialized(model_config, device_config,
lora_config, lora_config,
vision_language_config) vision_language_config,
cache_config)
return self._load_model_unserialized(model_config, device_config, return self._load_model_unserialized(model_config, device_config,
lora_config, lora_config,
vision_language_config) vision_language_config,
cache_config)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -215,6 +216,7 @@ class ArcticAttention(nn.Module): ...@@ -215,6 +216,7 @@ class ArcticAttention(nn.Module):
self, self,
config: ArcticConfig, config: ArcticConfig,
layer_idx: Optional[int] = None, layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -265,7 +267,8 @@ class ArcticAttention(nn.Module): ...@@ -265,7 +267,8 @@ class ArcticAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -288,6 +291,7 @@ class ArcticDecoderLayer(nn.Module): ...@@ -288,6 +291,7 @@ class ArcticDecoderLayer(nn.Module):
self, self,
config: ArcticConfig, config: ArcticConfig,
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -297,6 +301,7 @@ class ArcticDecoderLayer(nn.Module): ...@@ -297,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
self.use_residual = config.use_residual and is_moe_layer self.use_residual = config.use_residual and is_moe_layer
self.self_attn = ArcticAttention(config, self.self_attn = ArcticAttention(config,
layer_idx, layer_idx,
cache_config,
quant_config=quant_config) quant_config=quant_config)
self.block_sparse_moe = ArcticMoE( self.block_sparse_moe = ArcticMoE(
config, config,
...@@ -356,6 +361,7 @@ class ArcticModel(nn.Module): ...@@ -356,6 +361,7 @@ class ArcticModel(nn.Module):
def __init__( def __init__(
self, self,
config: ArcticConfig, config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -366,7 +372,10 @@ class ArcticModel(nn.Module): ...@@ -366,7 +372,10 @@ class ArcticModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=self.vocab_size) org_num_embeddings=self.vocab_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) ArcticDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
...@@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module): ...@@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config: ArcticConfig, config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
**kwargs) -> None: **kwargs) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.model = ArcticModel(config, quant_config) self.model = ArcticModel(config, cache_config, quant_config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.vocab_size, self.vocab_size,
......
...@@ -26,7 +26,7 @@ from torch import nn ...@@ -26,7 +26,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -111,6 +111,7 @@ class BaiChuanAttention(nn.Module): ...@@ -111,6 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str, position_embedding: str,
rope_theta: float = 10000, rope_theta: float = 10000,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -162,7 +163,10 @@ class BaiChuanAttention(nn.Module): ...@@ -162,7 +163,10 @@ class BaiChuanAttention(nn.Module):
base=self.rope_theta, base=self.rope_theta,
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads, self.head_dim, self.scaling) self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -197,6 +202,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -197,6 +202,7 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding, position_embedding=position_embedding,
rope_theta=rope_theta, rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
...@@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module): ...@@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -255,7 +262,8 @@ class BaiChuanModel(nn.Module): ...@@ -255,7 +262,8 @@ class BaiChuanModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, quant_config) BaiChuanDecoderLayer(config, position_embedding, cache_config,
quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -304,13 +312,15 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -304,13 +312,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
self, self,
config, config,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config) self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", quant_config, lora_config) super().__init__(config, "ROPE", cache_config, quant_config,
lora_config)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", quant_config, lora_config) super().__init__(config, "ALIBI", cache_config, quant_config,
lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
...@@ -404,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -404,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, "ROPE", quant_config, lora_config) super().__init__(config, "ROPE", cache_config, quant_config,
lora_config)
...@@ -24,6 +24,7 @@ from torch import nn ...@@ -24,6 +24,7 @@ from torch import nn
from transformers import BloomConfig from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -71,6 +72,7 @@ class BloomAttention(nn.Module): ...@@ -71,6 +72,7 @@ class BloomAttention(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -108,7 +110,8 @@ class BloomAttention(nn.Module): ...@@ -108,7 +110,8 @@ class BloomAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -158,6 +161,7 @@ class BloomBlock(nn.Module): ...@@ -158,6 +161,7 @@ class BloomBlock(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -165,7 +169,8 @@ class BloomBlock(nn.Module): ...@@ -165,7 +169,8 @@ class BloomBlock(nn.Module):
self.input_layernorm = nn.LayerNorm(hidden_size, self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, quant_config) self.self_attention = BloomAttention(config, cache_config,
quant_config)
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon) hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config) self.mlp = BloomMLP(config, quant_config)
...@@ -214,6 +219,7 @@ class BloomModel(nn.Module): ...@@ -214,6 +219,7 @@ class BloomModel(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -229,7 +235,7 @@ class BloomModel(nn.Module): ...@@ -229,7 +235,7 @@ class BloomModel(nn.Module):
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.h = nn.ModuleList([
BloomBlock(config, quant_config) BloomBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
...@@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module): ...@@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = BloomModel(config, quant_config) self.transformer = BloomModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -9,7 +9,7 @@ from torch import nn ...@@ -9,7 +9,7 @@ from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -34,6 +34,7 @@ class GLMAttention(nn.Module): ...@@ -34,6 +34,7 @@ class GLMAttention(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -90,6 +91,7 @@ class GLMAttention(nn.Module): ...@@ -90,6 +91,7 @@ class GLMAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
) )
def forward( def forward(
...@@ -167,6 +169,7 @@ class GLMBlock(nn.Module): ...@@ -167,6 +169,7 @@ class GLMBlock(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -181,7 +184,7 @@ class GLMBlock(nn.Module): ...@@ -181,7 +184,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon) eps=config.layernorm_epsilon)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, quant_config) self.self_attention = GLMAttention(config, cache_config, quant_config)
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
...@@ -237,6 +240,7 @@ class GLMTransformer(nn.Module): ...@@ -237,6 +240,7 @@ class GLMTransformer(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -246,8 +250,10 @@ class GLMTransformer(nn.Module): ...@@ -246,8 +250,10 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[GLMBlock(config, quant_config) for i in range(self.num_layers)]) GLMBlock(config, cache_config, quant_config)
for i in range(self.num_layers)
])
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
...@@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module): ...@@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -292,7 +299,7 @@ class ChatGLMModel(nn.Module): ...@@ -292,7 +299,7 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, quant_config) self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size, self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size) config.hidden_size)
...@@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = ChatGLMModel(config, quant_config) self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -29,6 +29,7 @@ from torch.nn.parameter import Parameter ...@@ -29,6 +29,7 @@ from torch.nn.parameter import Parameter
from transformers import CohereConfig from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -124,6 +125,7 @@ class CohereAttention(nn.Module): ...@@ -124,6 +125,7 @@ class CohereAttention(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -180,6 +182,7 @@ class CohereAttention(nn.Module): ...@@ -180,6 +182,7 @@ class CohereAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
) )
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads, self.q_norm = LayerNorm(param_shape=(self.num_heads,
...@@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module): ...@@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, quant_config=quant_config) self.self_attn = CohereAttention(config,
cache_config,
quant_config=quant_config)
self.mlp = CohereMLP(config, quant_config=quant_config) self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
...@@ -258,6 +264,7 @@ class CohereModel(nn.Module): ...@@ -258,6 +264,7 @@ class CohereModel(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -266,7 +273,7 @@ class CohereModel(nn.Module): ...@@ -266,7 +273,7 @@ class CohereModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
CohereDecoderLayer(config, quant_config=quant_config) CohereDecoderLayer(config, cache_config, quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = LayerNorm(param_shape=(config.hidden_size), self.norm = LayerNorm(param_shape=(config.hidden_size),
...@@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module): ...@@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -306,7 +314,7 @@ class CohereForCausalLM(nn.Module): ...@@ -306,7 +314,7 @@ class CohereForCausalLM(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size, self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale) scale=config.logit_scale)
self.model = CohereModel(config, quant_config) self.model = CohereModel(config, cache_config, quant_config)
self.sampler = Sampler() self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -166,6 +167,7 @@ class DbrxAttention(nn.Module): ...@@ -166,6 +167,7 @@ class DbrxAttention(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -221,6 +223,7 @@ class DbrxAttention(nn.Module): ...@@ -221,6 +223,7 @@ class DbrxAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
) )
def forward( def forward(
...@@ -279,10 +282,12 @@ class DbrxBlock(nn.Module): ...@@ -279,10 +282,12 @@ class DbrxBlock(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config) self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config)
self.ffn = DbrxExperts(config, quant_config) self.ffn = DbrxExperts(config, quant_config)
def forward( def forward(
...@@ -308,6 +313,7 @@ class DbrxModel(nn.Module): ...@@ -308,6 +313,7 @@ class DbrxModel(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -315,8 +321,10 @@ class DbrxModel(nn.Module): ...@@ -315,8 +321,10 @@ class DbrxModel(nn.Module):
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList([
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)]) DbrxBlock(config, cache_config, quant_config)
for _ in range(config.n_layers)
])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules(): for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias, if hasattr(module, "bias") and isinstance(module.bias,
...@@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module): ...@@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, quant_config) self.transformer = DbrxModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
......
...@@ -28,7 +28,7 @@ from typing import Iterable, Optional, Tuple ...@@ -28,7 +28,7 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__( def __init__(
self, self,
config: Optional[PretrainedConfig] = None, config: Optional[PretrainedConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer) config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config, super().__init__(config=config,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
lora_config=lora_config) lora_config=lora_config)
......
...@@ -28,6 +28,7 @@ from torch import nn ...@@ -28,6 +28,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -178,6 +179,7 @@ class DeepseekAttention(nn.Module): ...@@ -178,6 +179,7 @@ class DeepseekAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -229,7 +231,8 @@ class DeepseekAttention(nn.Module): ...@@ -229,7 +231,8 @@ class DeepseekAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -252,6 +255,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -252,6 +255,7 @@ class DeepseekDecoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -267,6 +271,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -267,6 +271,7 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
...@@ -321,6 +326,7 @@ class DeepseekModel(nn.Module): ...@@ -321,6 +326,7 @@ class DeepseekModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -332,7 +338,10 @@ class DeepseekModel(nn.Module): ...@@ -332,7 +338,10 @@ class DeepseekModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config) DeepseekDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module): ...@@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = DeepseekModel(config, quant_config) self.model = DeepseekModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment