Unverified commit cdc1fa12, authored by Harry Mellor and committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module): ...@@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states) qkv, _ = self.wqkv(hidden_states)
q, k, v = self.split_qkv(qkv) q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.wo(attn_output) output, _ = self.wo(attn_output)
return output return output
...@@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module): ...@@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module): ...@@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states = self.attention( hidden_states = self.attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -290,8 +284,6 @@ class InternLM2Model(nn.Module): ...@@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -305,15 +297,8 @@ class InternLM2Model(nn.Module): ...@@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
...@@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ...@@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
logits, _ = self.v_head(hidden_states) logits, _ = self.v_head(hidden_states)
return logits return logits
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module): ...@@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
visual_token_mask: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module): ...@@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
hidden_states = self.attention( hidden_states = self.attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model): ...@@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
visual_token_mask: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None,
...@@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model): ...@@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
visual_token_mask=visual_token_mask, visual_token_mask=visual_token_mask,
) )
......
...@@ -17,7 +17,6 @@ import torchvision.transforms as T ...@@ -17,7 +17,6 @@ import torchvision.transforms as T
from PIL import Image from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
...@@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
forward_kwargs = { forward_kwargs = {
"input_ids": input_ids, "input_ids": input_ids,
"positions": positions, "positions": positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"intermediate_tensors": intermediate_tensors, "intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds, "inputs_embeds": inputs_embeds,
} }
......
...@@ -21,12 +21,12 @@ ...@@ -21,12 +21,12 @@
"""Inference-only Jais model compatible with HuggingFace weights.""" """Inference-only Jais model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -123,12 +123,10 @@ class JAISAttention(nn.Module): ...@@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
...@@ -200,16 +198,10 @@ class JAISBlock(nn.Module): ...@@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(hidden_states=hidden_states, )
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
...@@ -266,8 +258,6 @@ class JAISModel(nn.Module): ...@@ -266,8 +258,6 @@ class JAISModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]: ) -> Union[IntermediateTensors, torch.Tensor]:
...@@ -285,11 +275,8 @@ class JAISModel(nn.Module): ...@@ -285,11 +275,8 @@ class JAISModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
...@@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]: ) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only Jamba model.""" """Inference-only Jamba model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import JambaConfig from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
...@@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata, hidden_states = self.mamba(hidden_states, mamba_cache_params)
mamba_cache_params)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual) hidden_states, residual)
...@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module): ...@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module): ...@@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
**kwargs, **kwargs,
): ):
...@@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module): ...@@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
hidden_states = self.self_attention( hidden_states = self.self_attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
...@@ -320,8 +311,6 @@ class JambaModel(nn.Module): ...@@ -320,8 +311,6 @@ class JambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -339,12 +328,9 @@ class JambaModel(nn.Module): ...@@ -339,12 +328,9 @@ class JambaModel(nn.Module):
kv_cache_index = 0 kv_cache_index = 0
mamba_cache_index = 0 mamba_cache_index = 0
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
kv_cache = None
layer_mamba_cache_params = None layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer): if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[kv_cache_index]
kv_cache_index += 1 kv_cache_index += 1
if isinstance(layer, JambaMambaDecoderLayer): if isinstance(layer, JambaMambaDecoderLayer):
current_state_layer = mamba_cache_index current_state_layer = mamba_cache_index
...@@ -355,8 +341,6 @@ class JambaModel(nn.Module): ...@@ -355,8 +341,6 @@ class JambaModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params) mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
...@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, mamba_cache_params,
attn_metadata, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
......
...@@ -22,13 +22,13 @@ ...@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module): ...@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module): ...@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states)
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
...@@ -347,8 +341,6 @@ class LlamaModel(nn.Module): ...@@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -363,11 +355,8 @@ class LlamaModel(nn.Module): ...@@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output
......
...@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION ...@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.llava import LlavaProcessor from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
......
...@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import ( ...@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -10,7 +10,6 @@ import torch.nn as nn ...@@ -10,7 +10,6 @@ import torch.nn as nn
from transformers import (BatchFeature, LlavaNextVideoConfig, from transformers import (BatchFeature, LlavaNextVideoConfig,
LlavaNextVideoProcessor) LlavaNextVideoProcessor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
......
...@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import ( ...@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
...@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module): ...@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
...@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module): ...@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata, hidden_states = self.mixer(hidden_states, mamba_cache_params)
mamba_cache_params)
return hidden_states, residual return hidden_states, residual
...@@ -125,7 +122,6 @@ class MambaModel(nn.Module): ...@@ -125,7 +122,6 @@ class MambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -146,7 +142,6 @@ class MambaModel(nn.Module): ...@@ -146,7 +142,6 @@ class MambaModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx( mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer)) i - self.start_layer))
...@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
...@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
mamba_cache_params, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model.""" """PyTorch MAMBA2 model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata ...@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import ( from vllm.model_executor.layers.mamba.mamba_mixer2 import (
...@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor], sequence_idx: Optional[torch.Tensor],
...@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata, hidden_states = self.mixer(hidden_states, mamba_cache_params,
mamba_cache_params, sequence_idx) sequence_idx)
return hidden_states, residual return hidden_states, residual
...@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module): ...@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module): ...@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
# proper continuous batching computation including # proper continuous batching computation including
# chunked prefill # chunked prefill
seq_idx = None seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate( for i, (srt, end) in enumerate(
...@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module): ...@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx( mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer), i - self.start_layer),
...@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
...@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
mamba_cache_params, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
......
...@@ -23,13 +23,13 @@ ...@@ -23,13 +23,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module): ...@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module): ...@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
q, k = q.float(), k.float() q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype) q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states * \ hidden_states = residual + hidden_states * \
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
...@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module): ...@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module): ...@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -29,7 +29,7 @@ import torch ...@@ -29,7 +29,7 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module): ...@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
q, _ = self.q_a_proj(hidden_states) q, _ = self.q_a_proj(hidden_states)
q = self.q_a_layernorm(q) q = self.q_a_layernorm(q)
...@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module): ...@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
v, [0, self.qk_head_dim - self.v_head_dim], v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim) value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output = attn_output.view( attn_output = attn_output.view(
-1, self.num_local_heads, -1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape( self.qk_head_dim)[..., :self.v_head_dim].reshape(
......
...@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast ...@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import ( from transformers.models.whisper.modeling_whisper import (
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import MultiModalFieldConfig
...@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
output = self.llm.model( output = self.llm.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings, inputs_embeds=vlm_embeddings,
) )
......
...@@ -37,7 +37,6 @@ from torch import nn ...@@ -37,7 +37,6 @@ from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
...@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
output = self.llm.model( output = self.llm.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings, inputs_embeds=vlm_embeddings,
) )
......
...@@ -22,13 +22,13 @@ ...@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module): ...@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module): ...@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module): ...@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -291,8 +285,6 @@ class MixtralModel(nn.Module): ...@@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -306,11 +298,8 @@ class MixtralModel(nn.Module): ...@@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -30,7 +30,7 @@ import torch.nn.functional as F ...@@ -30,7 +30,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -229,13 +229,11 @@ class MixtralAttention(nn.Module): ...@@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module): ...@@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module): ...@@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -333,8 +327,6 @@ class MixtralModel(nn.Module): ...@@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -348,11 +340,8 @@ class MixtralModel(nn.Module): ...@@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType ...@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend from vllm.attention.selector import _Backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module): ...@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
model_parallel_size = get_tensor_model_parallel_world_size() tensor_parallel_size = get_tp_group().world_size
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.attention_heads self.num_heads = config.attention_heads
self.head_dim = config.hidden_size // config.attention_heads self.head_dim = config.hidden_size // config.attention_heads
self.num_local_heads = self.num_heads // model_parallel_size self.num_local_heads = self.num_heads // tensor_parallel_size
self.q_size = self.num_local_heads * self.head_dim self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim self.kv_size = self.num_local_heads * self.head_dim
...@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.model_parallel_size = get_tensor_model_parallel_world_size() self.pipeline_parallel_rank = get_pp_group().rank_in_group
self.tensor_parallel_size = get_tp_group().world_size
self.num_heads = self.config.num_attention_heads self.num_heads = self.config.num_attention_heads
self.num_local_heads = self.num_heads // self.model_parallel_size self.num_local_heads = self.num_heads // self.tensor_parallel_size
self.num_key_value_heads = self.config.num_key_value_heads self.num_key_value_heads = self.config.num_key_value_heads
self.num_local_key_value_heads = \ self.num_local_key_value_heads = \
self.num_key_value_heads // self.model_parallel_size self.num_key_value_heads // self.tensor_parallel_size
self.dropout = config.dropout self.dropout = config.dropout
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // self.num_heads self.head_dim = config.hidden_size // self.num_heads
...@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor], cross_attention_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv_dec, _ = self.qkv_proj(hidden_states) qkv_dec, _ = self.qkv_proj(hidden_states)
q, _, _ = qkv_dec.split( q, _, _ = qkv_dec.split(
...@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
q = self.q_norm(q) q = self.q_norm(q)
if attention_mask is not None: if attention_mask is not None:
output = self._attention_with_mask(q, k, v, kv_cache, output = self._attention_with_mask(q, k, v, attention_mask,
attention_mask, kv_range_for_decode)
kv_range_for_decode,
attn_metadata)
else: else:
output = self.attn( output = self.attn(
q.view(-1, self.num_local_heads * self.head_dim), k, v, q.view(-1, self.num_local_heads * self.head_dim), k, v)
kv_cache, attn_metadata)
out, _ = self.o_proj(output) out, _ = self.o_proj(output)
return out return out
...@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
kv_cache: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]], kv_range_for_decode: List[Tuple[int, int]],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# Skip writing kv-cache for the initial profiling run. # Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1: if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32) i = torch.ones(1, dtype=torch.float32)
...@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
cross_attention_mask: torch.Tensor, cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor,
kv_cache: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
attention_mask=cross_attention_mask, attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_attn_gate.tanh( hidden_states = residual + self.cross_attn_attn_gate.tanh(
...@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module): ...@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
skip_cross_attention: bool, skip_cross_attention: bool,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers): for decoder_layer in self.layers:
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
if not skip_cross_attention: if not skip_cross_attention:
hidden_states = decoder_layer( hidden_states = decoder_layer(
...@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module): ...@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask= full_text_row_masked_out_mask=
full_text_row_masked_out_mask, full_text_row_masked_out_mask,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
) )
elif isinstance(decoder_layer, LlamaDecoderLayer): elif isinstance(decoder_layer, LlamaDecoderLayer):
hidden_states, residual = decoder_layer( hidden_states, residual = decoder_layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
residual=None, residual=None,
) )
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
...@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module): ...@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
skip_cross_attention: bool, skip_cross_attention: bool,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model( hidden_states = self.model(
...@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module): ...@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
skip_cross_attention=skip_cross_attention, skip_cross_attention=skip_cross_attention,
) )
return hidden_states return hidden_states
...@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object, **kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefill_tokens > 0 and \ if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0: attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported") raise ValueError("Chunk prefill not supported")
...@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
skip_cross_attention=skip_cross_attention, skip_cross_attention=skip_cross_attention,
) )
......
...@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, ...@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
...@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module): ...@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.q_norm is not None and self.k_norm is not None: if self.q_norm is not None and self.k_norm is not None:
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module): ...@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention # Self Attention
...@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module): ...@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
...@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): ...@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention # Self Attention
...@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): ...@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant): ...@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant): ...@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
# Apply blocks one-by-one. # Apply blocks one-by-one.
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: torch.LongTensor, positions: torch.LongTensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
hidden_states = self.model(input_ids, hidden_states = self.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment