Commit cdc1fa12 (unverified), authored by Harry Mellor, committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
......@@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.wo(attn_output)
return output
......@@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
......@@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
......@@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
......@@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
logits, _ = self.v_head(hidden_states)
return logits
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
visual_token_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
......@@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
visual_token_mask: Optional[torch.Tensor] = None,
......@@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
visual_token_mask=visual_token_mask,
)
......
......@@ -17,7 +17,6 @@ import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
......@@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
......
......@@ -21,12 +21,12 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
......@@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
attn_output, _ = self.c_proj(attn_output)
return attn_output
......@@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
attn_output = self.attn(hidden_states=hidden_states, )
# residual connection
hidden_states = attn_output + residual
......@@ -266,8 +258,6 @@ class JAISModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
......@@ -285,11 +275,8 @@ class JAISModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
for layer in self.h[self.start_layer:self.end_layer]:
hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......@@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
......
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Jamba model."""
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple
import torch
from torch import nn
from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
......@@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata,
mamba_cache_params)
hidden_states = self.mamba(hidden_states, mamba_cache_params)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
......@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
**kwargs,
):
......@@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
hidden_states = self.self_attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
......@@ -320,8 +311,6 @@ class JambaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
......@@ -339,12 +328,9 @@ class JambaModel(nn.Module):
kv_cache_index = 0
mamba_cache_index = 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
kv_cache = None
for layer in self.layers[self.start_layer:self.end_layer]:
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[kv_cache_index]
kv_cache_index += 1
if isinstance(layer, JambaMambaDecoderLayer):
current_state_layer = mamba_cache_index
......@@ -355,8 +341,6 @@ class JambaModel(nn.Module):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank:
......@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
......@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params,
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states
......
......@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
......@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
......@@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output
......
......@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn
......@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
......
......@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
......@@ -10,7 +10,6 @@ import torch.nn as nn
from transformers import (BatchFeature, LlavaNextVideoConfig,
LlavaNextVideoProcessor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
......@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
......
......@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
......@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
......
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple
import torch
from torch import nn
from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
......@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
......@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata,
mamba_cache_params)
hidden_states = self.mixer(hidden_states, mamba_cache_params)
return hidden_states, residual
......@@ -125,7 +122,6 @@ class MambaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
......@@ -146,7 +142,6 @@ class MambaModel(nn.Module):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer))
......@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
......@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors,
inputs_embeds)
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states
......
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model."""
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple
import torch
from torch import nn
......@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
......@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor],
......@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata,
mamba_cache_params, sequence_idx)
hidden_states = self.mixer(hidden_states, mamba_cache_params,
sequence_idx)
return hidden_states, residual
......@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
......@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
......@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer),
......@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
......@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors,
inputs_embeds)
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states
......
......@@ -23,13 +23,13 @@
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
......@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states * \
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
......@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
......@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
......
......@@ -29,7 +29,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
q, _ = self.q_a_proj(hidden_states)
q = self.q_a_layernorm(q)
......@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
attn_output = attn_output.view(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
......
......@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import (
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
......@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any,
) -> torch.Tensor:
......@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
output = self.llm.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings,
)
......
......@@ -37,7 +37,6 @@ from torch import nn
from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
......@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any,
) -> torch.Tensor:
......@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
output = self.llm.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings,
)
......
......@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
......@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
......@@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
......@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
......
......@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import numpy as np
import torch
......@@ -30,7 +30,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
......@@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
......@@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
......
......@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
prefix: str = ""):
super().__init__()
model_parallel_size = get_tensor_model_parallel_world_size()
tensor_parallel_size = get_tp_group().world_size
self.embed_dim = config.hidden_size
self.num_heads = config.attention_heads
self.head_dim = config.hidden_size // config.attention_heads
self.num_local_heads = self.num_heads // model_parallel_size
self.num_local_heads = self.num_heads // tensor_parallel_size
self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim
......@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
):
super().__init__()
self.config = config
self.model_parallel_size = get_tensor_model_parallel_world_size()
self.pipeline_parallel_rank = get_pp_group().rank_in_group
self.tensor_parallel_size = get_tp_group().world_size
self.num_heads = self.config.num_attention_heads
self.num_local_heads = self.num_heads // self.model_parallel_size
self.num_local_heads = self.num_heads // self.tensor_parallel_size
self.num_key_value_heads = self.config.num_key_value_heads
self.num_local_key_value_heads = \
self.num_key_value_heads // self.model_parallel_size
self.num_key_value_heads // self.tensor_parallel_size
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // self.num_heads
......@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv_dec, _ = self.qkv_proj(hidden_states)
q, _, _ = qkv_dec.split(
......@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
q = self.q_norm(q)
if attention_mask is not None:
output = self._attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
output = self._attention_with_mask(q, k, v, attention_mask,
kv_range_for_decode)
else:
output = self.attn(
q.view(-1, self.num_local_heads * self.head_dim), k, v,
kv_cache, attn_metadata)
q.view(-1, self.num_local_heads * self.head_dim), k, v)
out, _ = self.o_proj(output)
return out
......@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
......@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor,
kv_cache: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
cross_attention_states=cross_attention_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_attn_gate.tanh(
......@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
skip_cross_attention: bool,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
if not skip_cross_attention:
hidden_states = decoder_layer(
......@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=
full_text_row_masked_out_mask,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
elif isinstance(decoder_layer, LlamaDecoderLayer):
hidden_states, residual = decoder_layer(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
residual=None,
)
hidden_states = hidden_states + residual
......@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
skip_cross_attention: bool,
) -> torch.Tensor:
hidden_states = self.model(
......@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
skip_cross_attention=skip_cross_attention,
)
return hidden_states
......@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported")
......@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
skip_cross_attention=skip_cross_attention,
)
......
......@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
......@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.q_norm is not None and self.k_norm is not None:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention
......@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states, residual = self.post_attention_layernorm(
......@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention
......@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.input_layernorm(hidden_states)
......@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
residual = intermediate_tensors["residual"]
# Apply blocks one-by-one.
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
......@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
hidden_states = self.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment