Unverified commit 0fca3cdc, authored by Woosuk Kwon and committed by GitHub
Browse files

[Misc] Enhance attention selector (#4751)

parent e7c46b95
...@@ -30,6 +30,7 @@ from torch import nn ...@@ -30,6 +30,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -187,6 +188,7 @@ class Qwen2MoeAttention(nn.Module): ...@@ -187,6 +188,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -238,7 +240,8 @@ class Qwen2MoeAttention(nn.Module): ...@@ -238,7 +240,8 @@ class Qwen2MoeAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -276,6 +280,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -276,6 +280,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
if (config.num_experts is not None if (config.num_experts is not None
...@@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -339,7 +345,10 @@ class Qwen2MoeModel(nn.Module): ...@@ -339,7 +345,10 @@ class Qwen2MoeModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config) Qwen2MoeDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2MoeModel(config, quant_config) self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -26,6 +26,7 @@ from torch import nn ...@@ -26,6 +26,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -72,6 +73,7 @@ class StablelmAttention(nn.Module): ...@@ -72,6 +73,7 @@ class StablelmAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -124,7 +126,8 @@ class StablelmAttention(nn.Module): ...@@ -124,7 +126,8 @@ class StablelmAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_key_value_heads) num_kv_heads=self.num_key_value_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module): ...@@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = StablelmAttention(config) self.self_attn = StablelmAttention(config, cache_config, quant_config)
self.mlp = StablelmMLP(config, quant_config) self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05)) getattr(config, "layer_norm_eps", 1e-05))
...@@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module): ...@@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
...@@ -195,7 +200,7 @@ class StableLMEpochModel(nn.Module): ...@@ -195,7 +200,7 @@ class StableLMEpochModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
StablelmDecoderLayer(config, quant_config) StablelmDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
...@@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module): ...@@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = StableLMEpochModel(config, quant_config) self.model = StableLMEpochModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -25,6 +25,7 @@ from torch import nn ...@@ -25,6 +25,7 @@ from torch import nn
from transformers import Starcoder2Config from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module): ...@@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -101,6 +103,7 @@ class Starcoder2Attention(nn.Module): ...@@ -101,6 +103,7 @@ class Starcoder2Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
cache_config=cache_config,
) )
def forward( def forward(
...@@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config, quant_config=quant_config) self.self_attn = Starcoder2Attention(config,
cache_config,
quant_config=quant_config)
self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon) eps=config.norm_epsilon)
...@@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module): ...@@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -201,7 +208,9 @@ class Starcoder2Model(nn.Module): ...@@ -201,7 +208,9 @@ class Starcoder2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, quant_config=quant_config) Starcoder2DecoderLayer(config,
cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
...@@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = Starcoder2Model(config, quant_config=quant_config) self.model = Starcoder2Model(config,
cache_config,
quant_config=quant_config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings: if config.tie_word_embeddings:
......
...@@ -27,7 +27,7 @@ from torch import nn ...@@ -27,7 +27,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -89,6 +89,7 @@ class XverseAttention(nn.Module): ...@@ -89,6 +89,7 @@ class XverseAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -133,7 +134,8 @@ class XverseAttention(nn.Module): ...@@ -133,7 +134,8 @@ class XverseAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window) sliding_window=sliding_window,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -175,6 +178,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -175,6 +178,7 @@ class XverseDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=getattr(config, "bias", False), bias=getattr(config, "bias", False),
sliding_window=sliding_window, sliding_window=sliding_window,
cache_config=cache_config,
) )
self.mlp = XverseMLP( self.mlp = XverseMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -221,6 +225,7 @@ class XverseModel(nn.Module): ...@@ -221,6 +225,7 @@ class XverseModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -237,7 +242,7 @@ class XverseModel(nn.Module): ...@@ -237,7 +242,7 @@ class XverseModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
XverseDecoderLayer(config, quant_config) XverseDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module): ...@@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config=None, lora_config=None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = XverseModel(config, quant_config) self.model = XverseModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -31,7 +31,7 @@ class CacheEngine: ...@@ -31,7 +31,7 @@ class CacheEngine:
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config) self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_gpu_blocks = cache_config.num_gpu_blocks
...@@ -43,7 +43,15 @@ class CacheEngine: ...@@ -43,7 +43,15 @@ class CacheEngine:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend. # Get attention backend.
self.attn_backend = get_attn_backend(model_config.dtype) self.attn_backend = get_attn_backend(
model_config.get_num_attention_heads(parallel_config),
self.head_size,
self.num_kv_heads,
model_config.get_sliding_window(),
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
)
# Initialize the cache. # Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
...@@ -56,7 +64,7 @@ class CacheEngine: ...@@ -56,7 +64,7 @@ class CacheEngine:
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device.""" """Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape( kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size) num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = [] kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers): for _ in range(self.num_layers):
......
...@@ -53,7 +53,15 @@ class CPUModelRunner: ...@@ -53,7 +53,15 @@ class CPUModelRunner:
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window() self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(self.model_config.dtype) self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
...@@ -66,7 +74,8 @@ class CPUModelRunner: ...@@ -66,7 +74,8 @@ class CPUModelRunner:
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
lora_config=self.lora_config, lora_config=self.lora_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config) scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
def _prepare_prompt( def _prepare_prompt(
self, self,
...@@ -158,7 +167,6 @@ class CPUModelRunner: ...@@ -158,7 +167,6 @@ class CPUModelRunner:
decode_metadata=None, decode_metadata=None,
block_tables=torch.tensor([]), block_tables=torch.tensor([]),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input) multi_modal_input)
...@@ -242,7 +250,6 @@ class CPUModelRunner: ...@@ -242,7 +250,6 @@ class CPUModelRunner:
prefill_metadata=None, prefill_metadata=None,
decode_metadata=None, decode_metadata=None,
block_tables=block_tables, block_tables=block_tables,
kv_cache_dtype=self.kv_cache_dtype,
) )
return ( return (
input_tokens, input_tokens,
......
...@@ -53,7 +53,15 @@ class CPUCacheEngine: ...@@ -53,7 +53,15 @@ class CPUCacheEngine:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend. # Get attention backend.
self.attn_backend = get_attn_backend(model_config.dtype) self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
cache_config.cache_dtype,
self.block_size,
)
# Initialize the cache. # Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
......
...@@ -235,7 +235,6 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -235,7 +235,6 @@ class EmbeddingModelRunner(ModelRunner):
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata, prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata, decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, pooling_metadata, return (input_tokens, input_positions, attn_metadata, pooling_metadata,
......
...@@ -141,10 +141,18 @@ class ModelRunner: ...@@ -141,10 +141,18 @@ class ModelRunner:
self.graph_block_tables = np.zeros( self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
dtype=np.int32) dtype=np.int32)
self.attn_backend = get_attn_backend(self.model_config.dtype) self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)
# Lazy initialization # Lazy initialization
self.model: torch.nn.Module # Set after load_model self.model: nn.Module # Set after load_model
# Set if the backend is flashinfer. # Set if the backend is flashinfer.
self.flashinfer_workspace_buffer: torch.Tensor self.flashinfer_workspace_buffer: torch.Tensor
# Set after load_model. # Set after load_model.
...@@ -160,6 +168,7 @@ class ModelRunner: ...@@ -160,6 +168,7 @@ class ModelRunner:
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
) )
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
...@@ -753,7 +762,6 @@ class ModelRunner: ...@@ -753,7 +762,6 @@ class ModelRunner:
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata, prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata, decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
...@@ -965,7 +973,6 @@ class ModelRunner: ...@@ -965,7 +973,6 @@ class ModelRunner:
slot_mapping=slot_mapping[:batch_size], slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None, prefill_metadata=None,
decode_metadata=decode_metadata, decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype,
) )
if self.lora_config: if self.lora_config:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment