Unverified Commit cdc1fa12 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -69,8 +68,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -69,8 +68,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0, spec_step_index: int = 0,
...@@ -88,8 +85,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -88,8 +85,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
hidden_states, residual = self.mtp_block(positions=positions, hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None) residual=None)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
return self.shared_head(hidden_states) return self.shared_head(hidden_states)
...@@ -122,8 +117,6 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -122,8 +117,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
...@@ -131,8 +124,6 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -131,8 +124,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
input_ids, input_ids,
positions, positions,
kv_caches[spec_step_idx],
attn_metadata,
previous_hidden_states, previous_hidden_states,
inputs_embeds, inputs_embeds,
spec_step_idx, spec_step_idx,
...@@ -165,16 +156,14 @@ class DeepSeekMTP(nn.Module): ...@@ -165,16 +156,14 @@ class DeepSeekMTP(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions,
attn_metadata, previous_hidden_states, previous_hidden_states, inputs_embeds,
inputs_embeds, spec_step_idx) spec_step_idx)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -22,13 +22,13 @@ ...@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
...@@ -279,8 +279,6 @@ class DeepseekV2Attention(nn.Module): ...@@ -279,8 +279,6 @@ class DeepseekV2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0] q = self.q_a_proj(hidden_states)[0]
...@@ -313,7 +311,7 @@ class DeepseekV2Attention(nn.Module): ...@@ -313,7 +311,7 @@ class DeepseekV2Attention(nn.Module):
v = torch.nn.functional.pad( v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim], v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim) value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output = attn_output.view( attn_output = attn_output.view(
-1, self.num_local_heads, -1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape( self.qk_head_dim)[..., :self.v_head_dim].reshape(
...@@ -451,8 +449,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -451,8 +449,6 @@ class DeepseekV2MLAAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0] ckq = self.q_a_proj(hidden_states)[0]
...@@ -462,8 +458,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -462,8 +458,7 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe)
attn_metadata)
class DeepseekV2DecoderLayer(nn.Module): class DeepseekV2DecoderLayer(nn.Module):
...@@ -532,8 +527,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -532,8 +527,6 @@ class DeepseekV2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -546,8 +539,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -546,8 +539,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -608,8 +599,6 @@ class DeepseekV2Model(nn.Module): ...@@ -608,8 +599,6 @@ class DeepseekV2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -624,11 +613,8 @@ class DeepseekV2Model(nn.Module): ...@@ -624,11 +613,8 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -665,13 +651,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -665,13 +651,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -13,7 +13,6 @@ import torch.nn.functional as F ...@@ -13,7 +13,6 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
...@@ -595,8 +594,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -595,8 +594,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object): **kwargs: object):
...@@ -614,8 +611,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -614,8 +611,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model(input_ids, hidden_states = self.language_model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -121,8 +120,6 @@ class EAGLE(nn.Module): ...@@ -121,8 +120,6 @@ class EAGLE(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -140,8 +137,6 @@ class EAGLE(nn.Module): ...@@ -140,8 +137,6 @@ class EAGLE(nn.Module):
input_ids=None, input_ids=None,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
) )
return hidden_states return hidden_states
......
...@@ -24,12 +24,12 @@ ...@@ -24,12 +24,12 @@
# limitations under the License. # limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights.""" """Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -179,13 +179,11 @@ class ExaoneAttention(nn.Module): ...@@ -179,13 +179,11 @@ class ExaoneAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
...@@ -225,14 +223,10 @@ class ExaoneBlockAttention(nn.Module): ...@@ -225,14 +223,10 @@ class ExaoneBlockAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
return self.attention( return self.attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
...@@ -288,8 +282,6 @@ class ExaoneDecoderLayer(nn.Module): ...@@ -288,8 +282,6 @@ class ExaoneDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -301,8 +293,6 @@ class ExaoneDecoderLayer(nn.Module): ...@@ -301,8 +293,6 @@ class ExaoneDecoderLayer(nn.Module):
hidden_states = self.attn( hidden_states = self.attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -365,8 +355,6 @@ class ExaoneModel(nn.Module): ...@@ -365,8 +355,6 @@ class ExaoneModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -381,13 +369,10 @@ class ExaoneModel(nn.Module): ...@@ -381,13 +369,10 @@ class ExaoneModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
...@@ -471,14 +456,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -471,14 +456,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.transformer(input_ids, positions, kv_caches, model_output = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(
......
...@@ -20,14 +20,14 @@ ...@@ -20,14 +20,14 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -190,8 +190,6 @@ class FalconAttention(nn.Module): ...@@ -190,8 +190,6 @@ class FalconAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states) qkv, bias = self.query_key_value(hidden_states)
if bias is not None: if bias is not None:
...@@ -199,7 +197,7 @@ class FalconAttention(nn.Module): ...@@ -199,7 +197,7 @@ class FalconAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_rotary: if self.use_rotary:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, bias = self.dense(attn_output) attn_output, bias = self.dense(attn_output)
return attn_output, bias return attn_output, bias
...@@ -291,8 +289,6 @@ class FalconDecoderLayer(nn.Module): ...@@ -291,8 +289,6 @@ class FalconDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
...@@ -306,8 +302,6 @@ class FalconDecoderLayer(nn.Module): ...@@ -306,8 +302,6 @@ class FalconDecoderLayer(nn.Module):
attention_output, attention_bias = self.self_attention( attention_output, attention_bias = self.self_attention(
positions=positions, positions=positions,
hidden_states=attention_layernorm_out, hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
if self.reduce_row_parallel_results and attention_bias is not None: if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias attention_output += attention_bias
...@@ -384,8 +378,6 @@ class FalconModel(nn.Module): ...@@ -384,8 +378,6 @@ class FalconModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -396,14 +388,8 @@ class FalconModel(nn.Module): ...@@ -396,14 +388,8 @@ class FalconModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
...@@ -450,14 +436,11 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -450,14 +436,11 @@ class FalconForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...@@ -50,8 +49,7 @@ class Florence2LanguageModel(nn.Module): ...@@ -50,8 +49,7 @@ class Florence2LanguageModel(nn.Module):
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], encoder_positions: torch.Tensor) -> torch.Tensor:
attn_metadata: AttentionMetadata) -> torch.Tensor:
r""" r"""
Args: Args:
input_ids input_ids
...@@ -64,10 +62,6 @@ class Florence2LanguageModel(nn.Module): ...@@ -64,10 +62,6 @@ class Florence2LanguageModel(nn.Module):
Indices of *encoder* input sequence tokens in the vocabulary. Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions: encoder_positions:
Positions of *encoder* input sequence tokens. Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Model output torch.Tensor Model output torch.Tensor
""" """
...@@ -78,18 +72,14 @@ class Florence2LanguageModel(nn.Module): ...@@ -78,18 +72,14 @@ class Florence2LanguageModel(nn.Module):
# Run encoder attention if a non-zero number of encoder tokens # Run encoder attention if a non-zero number of encoder tokens
# are provided as input # are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions, positions=encoder_positions)
kv_caches=kv_caches,
attn_metadata=attn_metadata)
# decoder outputs consists of # decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn) # (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
decoder_input_ids=input_ids, decoder_input_ids=input_ids,
decoder_positions=positions, decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states)
kv_caches=kv_caches,
attn_metadata=attn_metadata)
return decoder_outputs return decoder_outputs
...@@ -122,8 +112,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -122,8 +112,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, encoder_positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
...@@ -136,15 +124,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -136,15 +124,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
torch.Tensor of *encoder* input token ids. torch.Tensor of *encoder* input token ids.
encoder_positions encoder_positions
torch.Tensor of *encoder* position indices torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Output torch.Tensor Output torch.Tensor
""" """
return self.model(input_ids, positions, encoder_input_ids, return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata) encoder_positions)
def compute_logits( def compute_logits(
self, self,
...@@ -213,8 +197,6 @@ class Florence2ForConditionalGeneration(nn.Module): ...@@ -213,8 +197,6 @@ class Florence2ForConditionalGeneration(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
*, *,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
...@@ -231,15 +213,11 @@ class Florence2ForConditionalGeneration(nn.Module): ...@@ -231,15 +213,11 @@ class Florence2ForConditionalGeneration(nn.Module):
torch.Tensor of *encoder* input token ids. torch.Tensor of *encoder* input token ids.
encoder_positions encoder_positions
torch.Tensor of *encoder* position indices torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Output torch.Tensor Output torch.Tensor
""" """
return self.language_model(input_ids, positions, encoder_input_ids, return self.language_model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata) encoder_positions)
def compute_logits( def compute_logits(
self, self,
......
...@@ -25,7 +25,6 @@ import torch.nn as nn ...@@ -25,7 +25,6 @@ import torch.nn as nn
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
FuyuProcessor) FuyuProcessor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -351,8 +350,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -351,8 +350,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -371,8 +368,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -371,8 +368,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model( hidden_states = self.language_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
......
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights.""" """Inference-only Gemma model compatible with HuggingFace weights."""
from functools import cache from functools import cache
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GemmaConfig from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -183,13 +183,11 @@ class GemmaAttention(nn.Module): ...@@ -183,13 +183,11 @@ class GemmaAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -233,8 +231,6 @@ class GemmaDecoderLayer(nn.Module): ...@@ -233,8 +231,6 @@ class GemmaDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -247,8 +243,6 @@ class GemmaDecoderLayer(nn.Module): ...@@ -247,8 +243,6 @@ class GemmaDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -298,8 +292,6 @@ class GemmaModel(nn.Module): ...@@ -298,8 +292,6 @@ class GemmaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -313,13 +305,10 @@ class GemmaModel(nn.Module): ...@@ -313,13 +305,10 @@ class GemmaModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -370,13 +359,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -370,13 +359,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Gemma2Config from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -164,13 +164,11 @@ class Gemma2Attention(nn.Module): ...@@ -164,13 +164,11 @@ class Gemma2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -220,8 +218,6 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -220,8 +218,6 @@ class Gemma2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None: if residual is None:
...@@ -233,8 +229,6 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -233,8 +229,6 @@ class Gemma2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
...@@ -284,8 +278,6 @@ class Gemma2Model(nn.Module): ...@@ -284,8 +278,6 @@ class Gemma2Model(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -300,13 +292,10 @@ class Gemma2Model(nn.Module): ...@@ -300,13 +292,10 @@ class Gemma2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -415,13 +404,10 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -415,13 +404,10 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# https://github.com/THUDM/CogAgent # https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights.""" """Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace from argparse import Namespace
from typing import List, Literal, Mapping, Optional, TypedDict, Union from typing import Literal, Mapping, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -15,7 +15,6 @@ from transformers import PreTrainedTokenizer, TensorType ...@@ -15,7 +15,6 @@ from transformers import PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention import AttentionMetadata
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -628,8 +627,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -628,8 +627,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -645,8 +642,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -645,8 +642,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
vision_embeddings) vision_embeddings)
input_ids = None input_ids = None
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
...@@ -18,13 +18,13 @@ ...@@ -18,13 +18,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPT2Config from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -92,12 +92,10 @@ class GPT2Attention(nn.Module): ...@@ -92,12 +92,10 @@ class GPT2Attention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
...@@ -164,16 +162,10 @@ class GPT2Block(nn.Module): ...@@ -164,16 +162,10 @@ class GPT2Block(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(hidden_states=hidden_states)
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
...@@ -222,8 +214,6 @@ class GPT2Model(nn.Module): ...@@ -222,8 +214,6 @@ class GPT2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -236,11 +226,8 @@ class GPT2Model(nn.Module): ...@@ -236,11 +226,8 @@ class GPT2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
...@@ -279,14 +266,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -279,14 +266,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -19,13 +19,13 @@ ...@@ -19,13 +19,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPTBigCodeConfig from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -101,8 +101,6 @@ class GPTBigCodeAttention(nn.Module): ...@@ -101,8 +101,6 @@ class GPTBigCodeAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split( q, k, v = qkv.split(
...@@ -112,7 +110,7 @@ class GPTBigCodeAttention(nn.Module): ...@@ -112,7 +110,7 @@ class GPTBigCodeAttention(nn.Module):
], ],
dim=-1, dim=-1,
) )
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
...@@ -173,16 +171,10 @@ class GPTBigCodeBlock(nn.Module): ...@@ -173,16 +171,10 @@ class GPTBigCodeBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(hidden_states=hidden_states, )
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
...@@ -234,8 +226,6 @@ class GPTBigCodeModel(nn.Module): ...@@ -234,8 +226,6 @@ class GPTBigCodeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -246,11 +236,8 @@ class GPTBigCodeModel(nn.Module): ...@@ -246,11 +236,8 @@ class GPTBigCodeModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
...@@ -302,14 +289,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -302,14 +289,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPTJConfig from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -104,13 +104,11 @@ class GPTJAttention(nn.Module): ...@@ -104,13 +104,11 @@ class GPTJAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.out_proj(attn_output) attn_output, _ = self.out_proj(attn_output)
return attn_output return attn_output
...@@ -167,16 +165,12 @@ class GPTJBlock(nn.Module): ...@@ -167,16 +165,12 @@ class GPTJBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
mlp_output = self.mlp(hidden_states) mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual hidden_states = attn_output + mlp_output + residual
...@@ -217,8 +211,6 @@ class GPTJModel(nn.Module): ...@@ -217,8 +211,6 @@ class GPTJModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -229,14 +221,8 @@ class GPTJModel(nn.Module): ...@@ -229,14 +221,8 @@ class GPTJModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(position_ids, hidden_states)
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
...@@ -273,14 +259,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -273,14 +259,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPTNeoXConfig from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -104,13 +104,11 @@ class GPTNeoXAttention(nn.Module): ...@@ -104,13 +104,11 @@ class GPTNeoXAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
...@@ -167,15 +165,11 @@ class GPTNeoXLayer(nn.Module): ...@@ -167,15 +165,11 @@ class GPTNeoXLayer(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states) attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention( attn_output = self.attention(
position_ids=position_ids, position_ids=position_ids,
hidden_states=attn_input, hidden_states=attn_input,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
if self.use_parallel_residual: if self.use_parallel_residual:
...@@ -230,8 +224,6 @@ class GPTNeoXModel(nn.Module): ...@@ -230,8 +224,6 @@ class GPTNeoXModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -242,14 +234,8 @@ class GPTNeoXModel(nn.Module): ...@@ -242,14 +234,8 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(position_ids, hidden_states)
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -285,14 +271,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -285,14 +271,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -22,13 +22,13 @@ ...@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights.""" """Inference-only IBM Granite model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GraniteConfig from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -166,13 +166,11 @@ class GraniteAttention(nn.Module): ...@@ -166,13 +166,11 @@ class GraniteAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -233,8 +231,6 @@ class GraniteDecoderLayer(nn.Module): ...@@ -233,8 +231,6 @@ class GraniteDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -242,8 +238,6 @@ class GraniteDecoderLayer(nn.Module): ...@@ -242,8 +238,6 @@ class GraniteDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
# Fully Connected # Fully Connected
...@@ -300,8 +294,6 @@ class GraniteModel(nn.Module): ...@@ -300,8 +294,6 @@ class GraniteModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -318,14 +310,8 @@ class GraniteModel(nn.Module): ...@@ -318,14 +310,8 @@ class GraniteModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -405,13 +391,10 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -405,13 +391,10 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output
......
...@@ -22,13 +22,13 @@ ...@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GraniteMoe model.""" """Inference-only GraniteMoe model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -173,13 +173,11 @@ class GraniteMoeAttention(nn.Module): ...@@ -173,13 +173,11 @@ class GraniteMoeAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -226,8 +224,6 @@ class GraniteMoeDecoderLayer(nn.Module): ...@@ -226,8 +224,6 @@ class GraniteMoeDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -235,8 +231,6 @@ class GraniteMoeDecoderLayer(nn.Module): ...@@ -235,8 +231,6 @@ class GraniteMoeDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
residual = hidden_states residual = hidden_states
...@@ -287,8 +281,6 @@ class GraniteMoeModel(nn.Module): ...@@ -287,8 +281,6 @@ class GraniteMoeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -303,11 +295,8 @@ class GraniteMoeModel(nn.Module): ...@@ -303,11 +295,8 @@ class GraniteMoeModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -377,13 +366,10 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -377,13 +366,10 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from array import array from array import array
from typing import List, Optional, Union from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.attention.backends.xformers import XFormersImpl from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
...@@ -217,13 +217,12 @@ class GritLM(LlamaForCausalLM): ...@@ -217,13 +217,12 @@ class GritLM(LlamaForCausalLM):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
# Change attention to non-causal for pooling tasks. # Change attention to non-causal for pooling tasks.
if self.runner_type == "pooling": if self.runner_type == "pooling":
attn_metadata = get_forward_context().attn_metadata
assert attn_metadata.prefill_metadata.attn_bias is None assert attn_metadata.prefill_metadata.attn_bias is None
attn_metadata.prefill_metadata.attn_bias = [ attn_metadata.prefill_metadata.attn_bias = [
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
...@@ -232,8 +231,6 @@ class GritLM(LlamaForCausalLM): ...@@ -232,8 +231,6 @@ class GritLM(LlamaForCausalLM):
return super().forward( return super().forward(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
**kwargs, **kwargs,
) )
......
...@@ -25,7 +25,6 @@ from torch import nn ...@@ -25,7 +25,6 @@ from torch import nn
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
Idefics3Processor) Idefics3Processor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
...@@ -563,8 +562,6 @@ class Idefics3Model(nn.Module): ...@@ -563,8 +562,6 @@ class Idefics3Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -572,8 +569,6 @@ class Idefics3Model(nn.Module): ...@@ -572,8 +569,6 @@ class Idefics3Model(nn.Module):
hidden_states = self.text_model( hidden_states = self.text_model(
input_ids, input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
...@@ -645,8 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -645,8 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -664,8 +657,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -664,8 +657,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.model.text_model(input_ids, hidden_states = self.model.text_model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload,
overload, runtime_checkable) runtime_checkable)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -11,7 +11,6 @@ from vllm.logger import init_logger ...@@ -11,7 +11,6 @@ from vllm.logger import init_logger
from vllm.utils import supports_kw from vllm.utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -46,8 +45,6 @@ class VllmModel(Protocol[T_co]): ...@@ -46,8 +45,6 @@ class VllmModel(Protocol[T_co]):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
) -> T_co: ) -> T_co:
... ...
...@@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: ...@@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
if not callable(model_forward): if not callable(model_forward):
return False return False
vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata") vllm_kws = ("input_ids", "positions")
missing_kws = tuple(kw for kw in vllm_kws missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_forward, kw)) if not supports_kw(model_forward, kw))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment