Unverified Commit 0fca3cdc authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Enhance attention selector (#4751)

parent e7c46b95
......@@ -30,6 +30,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -187,6 +188,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -238,7 +240,8 @@ class Qwen2MoeAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -276,6 +280,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
if (config.num_experts is not None
......@@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -339,7 +345,10 @@ class Qwen2MoeModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
Qwen2MoeDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, quant_config)
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -26,6 +26,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -72,6 +73,7 @@ class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
......@@ -124,7 +126,8 @@ class StablelmAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads)
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config)
def forward(
self,
......@@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
self.self_attn = StablelmAttention(config, cache_config, quant_config)
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
......@@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
......@@ -195,7 +200,7 @@ class StableLMEpochModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
StablelmDecoderLayer(config, quant_config)
StablelmDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
norm_eps = getattr(config, "norm_eps",
......@@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = StableLMEpochModel(config, quant_config)
self.model = StableLMEpochModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -25,6 +25,7 @@ from torch import nn
from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
......@@ -101,6 +103,7 @@ class Starcoder2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
def forward(
......@@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
self.self_attn = Starcoder2Attention(config,
cache_config,
quant_config=quant_config)
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
......@@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
......@@ -201,7 +208,9 @@ class Starcoder2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, quant_config=quant_config)
Starcoder2DecoderLayer(config,
cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
......@@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.model = Starcoder2Model(config, quant_config=quant_config)
self.model = Starcoder2Model(config,
cache_config,
quant_config=quant_config)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
......
......@@ -27,7 +27,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -89,6 +89,7 @@ class XverseAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -133,7 +134,8 @@ class XverseAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
sliding_window=sliding_window,
cache_config=cache_config)
def forward(
self,
......@@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -175,6 +178,7 @@ class XverseDecoderLayer(nn.Module):
quant_config=quant_config,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
cache_config=cache_config,
)
self.mlp = XverseMLP(
hidden_size=self.hidden_size,
......@@ -221,6 +225,7 @@ class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
......@@ -237,7 +242,7 @@ class XverseModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
XverseDecoderLayer(config, quant_config)
XverseDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config=None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = XverseModel(config, quant_config)
self.model = XverseModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -31,7 +31,7 @@ class CacheEngine:
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
......@@ -43,7 +43,15 @@ class CacheEngine:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(model_config.dtype)
self.attn_backend = get_attn_backend(
model_config.get_num_attention_heads(parallel_config),
self.head_size,
self.num_kv_heads,
model_config.get_sliding_window(),
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
)
# Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
......@@ -56,7 +64,7 @@ class CacheEngine:
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size)
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
......
......@@ -53,7 +53,15 @@ class CPUModelRunner:
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(self.model_config.dtype)
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
......@@ -66,7 +74,8 @@ class CPUModelRunner:
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
def _prepare_prompt(
self,
......@@ -158,7 +167,6 @@ class CPUModelRunner:
decode_metadata=None,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input)
......@@ -242,7 +250,6 @@ class CPUModelRunner:
prefill_metadata=None,
decode_metadata=None,
block_tables=block_tables,
kv_cache_dtype=self.kv_cache_dtype,
)
return (
input_tokens,
......
......@@ -53,7 +53,15 @@ class CPUCacheEngine:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(model_config.dtype)
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
cache_config.cache_dtype,
self.block_size,
)
# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
......
......@@ -235,7 +235,6 @@ class EmbeddingModelRunner(ModelRunner):
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
......
......@@ -141,10 +141,18 @@ class ModelRunner:
self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
dtype=np.int32)
self.attn_backend = get_attn_backend(self.model_config.dtype)
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)
# Lazy initialization
self.model: torch.nn.Module # Set after load_model
self.model: nn.Module # Set after load_model
# Set if the backend is flashinfer.
self.flashinfer_workspace_buffer: torch.Tensor
# Set after load_model.
......@@ -160,6 +168,7 @@ class ModelRunner:
vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
)
self.model_memory_usage = m.consumed_memory
......@@ -753,7 +762,6 @@ class ModelRunner:
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata,
......@@ -965,7 +973,6 @@ class ModelRunner:
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
if self.lora_config:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment