Unverified commit cdc1fa12, authored by Harry Mellor and committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
...@@ -74,8 +74,6 @@ def forward( ...@@ -74,8 +74,6 @@ def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
... ...
``` ```
......
...@@ -16,8 +16,6 @@ Further update the model as follows: ...@@ -16,8 +16,6 @@ Further update the model as follows:
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor, + pixel_values: torch.Tensor,
) -> SamplerOutput: ) -> SamplerOutput:
``` ```
......
...@@ -644,11 +644,7 @@ def _run_encoder_attention_test( ...@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view( reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward( return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
reshaped_query, packed_qkv.key, packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device), attn_metadata)
def _run_decoder_self_attention_test( def _run_decoder_self_attention_test(
...@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test( ...@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
& attn_metadata & attn_metadata
''' '''
attn = test_rsrcs.attn attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata, vllm_config): with set_forward_context(attn_metadata, vllm_config):
...@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test( ...@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view( reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
kv_cache, attn_metadata)
def _run_encoder_decoder_cross_attention_test( def _run_encoder_decoder_cross_attention_test(
...@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
assert decoder_test_params.packed_qkvo.packed_qkv is not None assert decoder_test_params.packed_qkvo.packed_qkv is not None
attn = test_rsrcs.attn attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None: if cross_test_params is None:
key = None key = None
value = None value = None
...@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query, key, value, kv_cache, return attn.forward(reshaped_query, key, value)
attn_metadata)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
......
...@@ -7,7 +7,7 @@ import torch.nn as nn ...@@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
...@@ -153,15 +153,10 @@ class Attention(nn.Module): ...@@ -153,15 +153,10 @@ class Attention(nn.Module):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
# directly, use `self.kv_cache` and
# `get_forward_context().attn_metadata` instead.
if self.calculate_kv_scales: if self.calculate_kv_scales:
ctx_attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
if ctx_attn_metadata.enable_kv_scales_calculation: if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value) self.calc_kv_scales(key, value)
if self.use_output: if self.use_output:
output = torch.empty_like(query) output = torch.empty_like(query)
...@@ -177,14 +172,14 @@ class Attention(nn.Module): ...@@ -177,14 +172,14 @@ class Attention(nn.Module):
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call: if self.use_direct_call:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self, self.impl.forward(self,
query, query,
key, key,
value, value,
self_kv_cache, self_kv_cache,
ctx_attn_metadata, attn_metadata,
output=output) output=output)
else: else:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
...@@ -193,10 +188,10 @@ class Attention(nn.Module): ...@@ -193,10 +188,10 @@ class Attention(nn.Module):
else: else:
if self.use_direct_call: if self.use_direct_call:
forward_context = get_forward_context() forward_context = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value, return self.impl.forward(self, query, key, value,
self_kv_cache, ctx_attn_metadata) self_kv_cache, attn_metadata)
else: else:
return torch.ops.vllm.unified_attention( return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name) query, key, value, self.layer_name)
......
...@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter ...@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -130,14 +131,14 @@ class MambaMixer(CustomOp): ...@@ -130,14 +131,14 @@ class MambaMixer(CustomOp):
) if use_rms_norm else None ) if use_rms_norm else None
def forward_native(self, hidden_states: torch.Tensor, def forward_native(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor, ssm_state: torch.Tensor): conv_state: torch.Tensor, ssm_state: torch.Tensor):
pass pass
def forward_cuda(self, hidden_states: torch.Tensor, def forward_cuda(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams): mamba_cache_params: MambaCacheParams):
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2) hidden_states, gate = projected_states.chunk(2, dim=-2)
......
...@@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -376,17 +377,16 @@ class MambaMixer2(CustomOp): ...@@ -376,17 +377,16 @@ class MambaMixer2(CustomOp):
eps=rms_norm_eps) eps=rms_norm_eps)
def forward_native(self, hidden_states: torch.Tensor, def forward_native(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor, ssm_state: torch.Tensor): conv_state: torch.Tensor, ssm_state: torch.Tensor):
pass pass
def forward_cuda( def forward_cuda(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None, sequence_idx: Optional[torch.Tensor] = None,
): ):
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
seq_len, _ = hidden_states.shape seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size groups_time_state_size = self.n_groups * self.ssm_state_size
......
...@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T: ...@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
return cls return cls
# Lazy import # Lazy import
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import PoolingType
...@@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T: ...@@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T:
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = super().forward(input_ids, positions, kv_caches, hidden_states = super().forward(input_ids, positions,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds) inputs_embeds)
logits, _ = self.score(hidden_states) logits, _ = self.score(hidden_states)
......
...@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union ...@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -283,13 +283,11 @@ class ArcticAttention(nn.Module): ...@@ -283,13 +283,11 @@ class ArcticAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module): ...@@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual_input = hidden_states residual_input = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual_input + hidden_states hidden_states = residual_input + hidden_states
...@@ -400,8 +394,6 @@ class ArcticModel(nn.Module): ...@@ -400,8 +394,6 @@ class ArcticModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -413,11 +405,8 @@ class ArcticModel(nn.Module): ...@@ -413,11 +405,8 @@ class ArcticModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
...@@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature ...@@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor from transformers.models.aria.processing_aria import AriaProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states = self.language_model( hidden_states = self.language_model(
input_ids, input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
......
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" """Inference-only BaiChuan model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module): ...@@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module): ...@@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module): ...@@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only Bamba model.""" """Inference-only Bamba model."""
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import BambaConfig from transformers import BambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module): ...@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None, sequence_idx: Optional[torch.Tensor] = None,
...@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module): ...@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata, hidden_states = self.mamba(hidden_states, mamba_cache_params,
mamba_cache_params, sequence_idx) sequence_idx)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual) hidden_states, residual)
...@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module): ...@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module): ...@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
**kwargs, **kwargs,
): ):
...@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module): ...@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
hidden_states = self.self_attention( hidden_states = self.self_attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
...@@ -312,8 +305,6 @@ class BambaModel(nn.Module): ...@@ -312,8 +305,6 @@ class BambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -323,6 +314,7 @@ class BambaModel(nn.Module): ...@@ -323,6 +314,7 @@ class BambaModel(nn.Module):
# proper continuous batching computation including # proper continuous batching computation including
# chunked prefill # chunked prefill
seq_idx = None seq_idx = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate( for i, (srt, end) in enumerate(
...@@ -348,9 +340,7 @@ class BambaModel(nn.Module): ...@@ -348,9 +340,7 @@ class BambaModel(nn.Module):
num_attn = 0 num_attn = 0
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
kv_cache = None
if isinstance(layer, BambaAttentionDecoderLayer): if isinstance(layer, BambaAttentionDecoderLayer):
kv_cache = kv_caches[num_attn]
num_attn += 1 num_attn += 1
layer_mamba_cache_params = None layer_mamba_cache_params = None
...@@ -361,8 +351,6 @@ class BambaModel(nn.Module): ...@@ -361,8 +351,6 @@ class BambaModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params, mamba_cache_params=layer_mamba_cache_params,
sequence_idx=seq_idx, sequence_idx=seq_idx,
...@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
...@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape()) *self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, mamba_cache_params,
attn_metadata, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
......
...@@ -19,14 +19,14 @@ ...@@ -19,14 +19,14 @@
# limitations under the License. # limitations under the License.
"""PyTorch BART model.""" """PyTorch BART model."""
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import BartConfig from transformers import BartConfig
from transformers.utils import logging from transformers.utils import logging
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module): ...@@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER) attn_type=AttentionType.ENCODER)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
...@@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module): ...@@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
attn_type=AttentionType.DECODER) attn_type=AttentionType.DECODER)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
...@@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module): ...@@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module):
def forward( def forward(
self, self,
decoder_hidden_states: torch.Tensor, decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
...@@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module): ...@@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module):
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
dim=-1) dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
...@@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module): ...@@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module):
self.final_layer_norm = nn.LayerNorm(self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_metadata: AttentionMetadata) -> torch.Tensor:
r""" r"""
Args: Args:
hidden_states hidden_states
torch.Tensor of *encoder* input embeddings. torch.Tensor of *encoder* input embeddings.
kv_cache:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Encoder layer output torch.Tensor Encoder layer output torch.Tensor
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn(hidden_states=hidden_states, hidden_states = self.self_attn(hidden_states=hidden_states)
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
...@@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module): ...@@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module):
def forward( def forward(
self, self,
decoder_hidden_states: torch.Tensor, decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
Args: Args:
decoder_hidden_states decoder_hidden_states
torch.Tensor of *decoder* input embeddings. torch.Tensor of *decoder* input embeddings.
kv_cache:
KV cache tensor
attn_metadata:
vLLM Attention metadata structure
encoder_hidden_states encoder_hidden_states
torch.Tensor of *encoder* input embeddings. torch.Tensor of *encoder* input embeddings.
Returns: Returns:
...@@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module): ...@@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module):
residual = decoder_hidden_states residual = decoder_hidden_states
# Self Attention # Self Attention
hidden_states = self.self_attn(hidden_states=decoder_hidden_states, hidden_states = self.self_attn(hidden_states=decoder_hidden_states)
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
...@@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module): ...@@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module):
hidden_states = self.encoder_attn( hidden_states = self.encoder_attn(
decoder_hidden_states=hidden_states, decoder_hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
) )
...@@ -609,9 +588,8 @@ class BartEncoder(nn.Module): ...@@ -609,9 +588,8 @@ class BartEncoder(nn.Module):
self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layernorm_embedding = nn.LayerNorm(embed_dim)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor,
kv_caches: List[torch.Tensor], positions: torch.Tensor) -> torch.Tensor:
attn_metadata: AttentionMetadata) -> torch.Tensor:
r""" r"""
Args: Args:
input_ids input_ids
...@@ -620,10 +598,6 @@ class BartEncoder(nn.Module): ...@@ -620,10 +598,6 @@ class BartEncoder(nn.Module):
provide it. provide it.
positions positions
Positions of *encoder* input sequence tokens. Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Decoder output torch.Tensor Decoder output torch.Tensor
""" """
...@@ -636,12 +610,8 @@ class BartEncoder(nn.Module): ...@@ -636,12 +610,8 @@ class BartEncoder(nn.Module):
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
for idx, encoder_layer in enumerate(self.layers): for encoder_layer in self.layers:
hidden_states = encoder_layer( hidden_states = encoder_layer(hidden_states=hidden_states)
hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
return hidden_states return hidden_states
...@@ -693,9 +663,7 @@ class BartDecoder(nn.Module): ...@@ -693,9 +663,7 @@ class BartDecoder(nn.Module):
def forward(self, decoder_input_ids: torch.Tensor, def forward(self, decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor, decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r""" r"""
Args: Args:
decoder_input_ids decoder_input_ids
...@@ -706,10 +674,6 @@ class BartDecoder(nn.Module): ...@@ -706,10 +674,6 @@ class BartDecoder(nn.Module):
Positions of *decoder* input sequence tokens. Positions of *decoder* input sequence tokens.
encoder_hidden_states: encoder_hidden_states:
Tensor of encoder output embeddings Tensor of encoder output embeddings
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Decoder output torch.Tensor Decoder output torch.Tensor
""" """
...@@ -725,11 +689,9 @@ class BartDecoder(nn.Module): ...@@ -725,11 +689,9 @@ class BartDecoder(nn.Module):
# decoder layers # decoder layers
for idx, decoder_layer in enumerate(self.layers): for decoder_layer in self.layers:
hidden_states = decoder_layer( hidden_states = decoder_layer(
decoder_hidden_states=hidden_states, decoder_hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
) )
...@@ -768,8 +730,7 @@ class BartModel(nn.Module): ...@@ -768,8 +730,7 @@ class BartModel(nn.Module):
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], encoder_positions: torch.Tensor) -> torch.Tensor:
attn_metadata: AttentionMetadata) -> torch.Tensor:
r""" r"""
Args: Args:
input_ids input_ids
...@@ -782,10 +743,6 @@ class BartModel(nn.Module): ...@@ -782,10 +743,6 @@ class BartModel(nn.Module):
Indices of *encoder* input sequence tokens in the vocabulary. Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions: encoder_positions:
Positions of *encoder* input sequence tokens. Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Model output torch.Tensor Model output torch.Tensor
""" """
...@@ -796,18 +753,14 @@ class BartModel(nn.Module): ...@@ -796,18 +753,14 @@ class BartModel(nn.Module):
# Run encoder attention if a non-zero number of encoder tokens # Run encoder attention if a non-zero number of encoder tokens
# are provided as input # are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions, positions=encoder_positions)
kv_caches=kv_caches,
attn_metadata=attn_metadata)
# decoder outputs consists of # decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn) # (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
decoder_input_ids=input_ids, decoder_input_ids=input_ids,
decoder_positions=positions, decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states)
kv_caches=kv_caches,
attn_metadata=attn_metadata)
return decoder_outputs return decoder_outputs
...@@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module): ...@@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
*, *,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
...@@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module): ...@@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module):
torch.Tensor of *encoder* input token ids. torch.Tensor of *encoder* input token ids.
encoder_positions encoder_positions
torch.Tensor of *encoder* position indices torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Output torch.Tensor Output torch.Tensor
""" """
return self.model(input_ids, positions, encoder_input_ids, return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata) encoder_positions)
def compute_logits( def compute_logits(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import BertConfig from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -113,12 +114,9 @@ class BertEncoder(nn.Module): ...@@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
for i in range(len(self.layer)): for layer in self.layer:
layer = self.layer[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
return hidden_states return hidden_states
...@@ -152,13 +150,8 @@ class BertLayer(nn.Module): ...@@ -152,13 +150,8 @@ class BertLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
def forward( def forward(self, hidden_states: torch.Tensor):
self, attn_output = self.attention(hidden_states)
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
):
attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
intermediate_output = self.intermediate(attn_output) intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output) output = self.output(intermediate_output, attn_output)
return output return output
...@@ -191,10 +184,8 @@ class BertAttention(nn.Module): ...@@ -191,10 +184,8 @@ class BertAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
self_output = self.self(hidden_states, kv_cache, attn_metadata) self_output = self.self(hidden_states)
return self.output(self_output, hidden_states) return self.output(self_output, hidden_states)
...@@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module): ...@@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
output = self.attn(q, k, v, kv_cache, attn_metadata) output = self.attn(q, k, v)
return output return output
...@@ -343,8 +332,6 @@ class BertModel(nn.Module): ...@@ -343,8 +332,6 @@ class BertModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
...@@ -352,13 +339,14 @@ class BertModel(nn.Module): ...@@ -352,13 +339,14 @@ class BertModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor") assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings( hidden_states = self.embeddings(
input_ids=input_ids, input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor, seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
return self.encoder(hidden_states, kv_caches, attn_metadata) return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
...@@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module): ...@@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.model(input_ids=input_ids, return self.model(input_ids=input_ids,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors)
attn_metadata=attn_metadata)
def pooler( def pooler(
self, self,
...@@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.bert(input_ids=input_ids, return self.bert(input_ids=input_ids,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import torch import torch
...@@ -9,7 +9,6 @@ import torch.nn as nn ...@@ -9,7 +9,6 @@ import torch.nn as nn
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward) apply_chunking_to_forward)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
......
...@@ -18,13 +18,13 @@ ...@@ -18,13 +18,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.""" """Inference-only BLOOM model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import BloomConfig from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -126,13 +126,11 @@ class BloomAttention(nn.Module): ...@@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
del position_ids # Unused. del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
...@@ -193,8 +191,6 @@ class BloomBlock(nn.Module): ...@@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
...@@ -209,8 +205,6 @@ class BloomBlock(nn.Module): ...@@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
attention_output = self.self_attention( attention_output = self.self_attention(
position_ids=position_ids, position_ids=position_ids,
hidden_states=layernorm_output, hidden_states=layernorm_output,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
attention_output = attention_output + residual attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output) layernorm_output = self.post_attention_layernorm(attention_output)
...@@ -266,8 +260,6 @@ class BloomModel(nn.Module): ...@@ -266,8 +260,6 @@ class BloomModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -279,14 +271,8 @@ class BloomModel(nn.Module): ...@@ -279,14 +271,8 @@ class BloomModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(position_ids, hidden_states)
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
...@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
...@@ -10,7 +10,7 @@ import torch.nn.functional as F ...@@ -10,7 +10,7 @@ import torch.nn.functional as F
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
ChameleonVQVAEConfig) ChameleonVQVAEConfig)
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module): ...@@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -906,8 +896,6 @@ class ChameleonModel(nn.Module): ...@@ -906,8 +896,6 @@ class ChameleonModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -921,13 +909,10 @@ class ChameleonModel(nn.Module): ...@@ -921,13 +909,10 @@ class ChameleonModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
...@@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.model(input_ids, hidden_states = self.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# Adapted from # Adapted from
# https://github.com/THUDM/ChatGLM2-6B # https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -108,19 +108,11 @@ class GLMAttention(nn.Module): ...@@ -108,19 +108,11 @@ class GLMAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn( context_layer = self.attn(q, k, v)
q,
k,
v,
kv_cache,
attn_metadata,
)
attn_output, _ = self.dense(context_layer) attn_output, _ = self.dense(context_layer)
return attn_output return attn_output
...@@ -215,8 +207,6 @@ class GLMBlock(nn.Module): ...@@ -215,8 +207,6 @@ class GLMBlock(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# hidden_states: [num_tokens, h] # hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
...@@ -225,8 +215,6 @@ class GLMBlock(nn.Module): ...@@ -225,8 +215,6 @@ class GLMBlock(nn.Module):
attention_output = self.self_attention( attention_output = self.self_attention(
hidden_states=layernorm_output, hidden_states=layernorm_output,
position_ids=position_ids, position_ids=position_ids,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Residual connection. # Residual connection.
...@@ -289,17 +277,10 @@ class GLMTransformer(nn.Module): ...@@ -289,17 +277,10 @@ class GLMTransformer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(hidden_states=hidden_states,
hidden_states = layer( position_ids=position_ids)
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
...@@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module): ...@@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module): ...@@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module):
hidden_states = self.encoder( hidden_states = self.encoder(
hidden_states=hidden_states, hidden_states=hidden_states,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
) )
return hidden_states return hidden_states
...@@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP): ...@@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
...@@ -21,14 +21,14 @@ ...@@ -21,14 +21,14 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from transformers import CohereConfig from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -218,8 +218,6 @@ class CohereAttention(nn.Module): ...@@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -227,7 +225,7 @@ class CohereAttention(nn.Module): ...@@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
if self.v1 or self.sliding_window: if self.v1 or self.sliding_window:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module): ...@@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module): ...@@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention = self.self_attn( hidden_states_attention = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states_mlp = self.mlp(hidden_states) hidden_states_mlp = self.mlp(hidden_states)
# Add everything together # Add everything together
...@@ -311,8 +305,6 @@ class CohereModel(nn.Module): ...@@ -311,8 +305,6 @@ class CohereModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -326,13 +318,10 @@ class CohereModel(nn.Module): ...@@ -326,13 +318,10 @@ class CohereModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -230,15 +230,13 @@ class DbrxAttention(nn.Module): ...@@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None: if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
hidden_states, _ = self.out_proj(attn_output) hidden_states, _ = self.out_proj(attn_output)
return hidden_states return hidden_states
...@@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module): ...@@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.norm_1(hidden_states) hidden_states = self.norm_1(hidden_states)
x = self.attn( x = self.attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + x hidden_states = residual + x
residual = hidden_states residual = hidden_states
...@@ -303,14 +297,10 @@ class DbrxBlock(nn.Module): ...@@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm( hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.ffn(hidden_states) hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
...@@ -353,8 +343,6 @@ class DbrxModel(nn.Module): ...@@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -366,14 +354,8 @@ class DbrxModel(nn.Module): ...@@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
else: else:
assert intermediate_tensors assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for block in self.blocks[self.start_layer:self.end_layer]:
block = self.blocks[i] hidden_states = block(position_ids, hidden_states)
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
...@@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -22,13 +22,13 @@ ...@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Deepseek model.""" """Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module): ...@@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -370,8 +364,6 @@ class DeepseekModel(nn.Module): ...@@ -370,8 +364,6 @@ class DeepseekModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -384,11 +376,8 @@ class DeepseekModel(nn.Module): ...@@ -384,11 +376,8 @@ class DeepseekModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment