Unverified Commit 925f3332 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Core] Refactor Attention Take 2 (#3462)

parent b0dfa91d
...@@ -19,16 +19,15 @@ ...@@ -19,16 +19,15 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig from transformers import FalconConfig as HF_FalconConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -48,7 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -48,7 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -177,8 +175,8 @@ class FalconAttention(nn.Module): ...@@ -177,8 +175,8 @@ class FalconAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states) qkv, bias = self.query_key_value(hidden_states)
if bias is not None: if bias is not None:
...@@ -186,8 +184,7 @@ class FalconAttention(nn.Module): ...@@ -186,8 +184,7 @@ class FalconAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_rotary: if self.use_rotary:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output, bias = self.dense(attn_output) attn_output, bias = self.dense(attn_output)
return attn_output, bias return attn_output, bias
...@@ -263,8 +260,8 @@ class FalconDecoderLayer(nn.Module): ...@@ -263,8 +260,8 @@ class FalconDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
...@@ -279,7 +276,7 @@ class FalconDecoderLayer(nn.Module): ...@@ -279,7 +276,7 @@ class FalconDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=attention_layernorm_out, hidden_states=attention_layernorm_out,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
if self.reduce_row_parallel_results and attention_bias is not None: if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias attention_output += attention_bias
...@@ -343,8 +340,8 @@ class FalconModel(nn.Module): ...@@ -343,8 +340,8 @@ class FalconModel(nn.Module):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)): for i in range(len(self.h)):
...@@ -353,7 +350,7 @@ class FalconModel(nn.Module): ...@@ -353,7 +350,7 @@ class FalconModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
) )
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -378,14 +375,14 @@ class FalconForCausalLM(nn.Module): ...@@ -378,14 +375,14 @@ class FalconForCausalLM(nn.Module):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
input_metadata, attn_metadata,
) )
return hidden_states return hidden_states
......
...@@ -20,10 +20,9 @@ import torch ...@@ -20,10 +20,9 @@ import torch
from torch import nn from torch import nn
from transformers import GemmaConfig from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GemmaMLP(nn.Module): class GemmaMLP(nn.Module):
...@@ -133,14 +130,13 @@ class GemmaAttention(nn.Module): ...@@ -133,14 +130,13 @@ class GemmaAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -177,8 +173,8 @@ class GemmaDecoderLayer(nn.Module): ...@@ -177,8 +173,8 @@ class GemmaDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -192,7 +188,7 @@ class GemmaDecoderLayer(nn.Module): ...@@ -192,7 +188,7 @@ class GemmaDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -226,8 +222,8 @@ class GemmaModel(nn.Module): ...@@ -226,8 +222,8 @@ class GemmaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Normalize the embedding by sqrt(hidden_size) # Normalize the embedding by sqrt(hidden_size)
...@@ -240,7 +236,7 @@ class GemmaModel(nn.Module): ...@@ -240,7 +236,7 @@ class GemmaModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
...@@ -290,11 +286,11 @@ class GemmaForCausalLM(nn.Module): ...@@ -290,11 +286,11 @@ class GemmaForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -17,15 +17,14 @@ ...@@ -17,15 +17,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import GPT2Config from transformers import GPT2Config
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
...@@ -79,14 +76,12 @@ class GPT2Attention(nn.Module): ...@@ -79,14 +76,12 @@ class GPT2Attention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
...@@ -144,15 +139,15 @@ class GPT2Block(nn.Module): ...@@ -144,15 +139,15 @@ class GPT2Block(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
...@@ -190,8 +185,8 @@ class GPT2Model(nn.Module): ...@@ -190,8 +185,8 @@ class GPT2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.wte(input_ids) inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
...@@ -199,7 +194,7 @@ class GPT2Model(nn.Module): ...@@ -199,7 +194,7 @@ class GPT2Model(nn.Module):
for i in range(len(self.h)): for i in range(len(self.h)):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata) hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -224,11 +219,11 @@ class GPT2LMHeadModel(nn.Module): ...@@ -224,11 +219,11 @@ class GPT2LMHeadModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -18,15 +18,14 @@ ...@@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import GPTBigCodeConfig from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module): class GPTBigCodeAttention(nn.Module):
...@@ -94,8 +91,8 @@ class GPTBigCodeAttention(nn.Module): ...@@ -94,8 +91,8 @@ class GPTBigCodeAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split( q, k, v = qkv.split(
...@@ -105,9 +102,7 @@ class GPTBigCodeAttention(nn.Module): ...@@ -105,9 +102,7 @@ class GPTBigCodeAttention(nn.Module):
], ],
dim=-1, dim=-1,
) )
key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
...@@ -165,15 +160,15 @@ class GPTBigCodeBlock(nn.Module): ...@@ -165,15 +160,15 @@ class GPTBigCodeBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
...@@ -211,8 +206,8 @@ class GPTBigCodeModel(nn.Module): ...@@ -211,8 +206,8 @@ class GPTBigCodeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.wte(input_ids) inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
...@@ -220,7 +215,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -220,7 +215,7 @@ class GPTBigCodeModel(nn.Module):
for i in range(len(self.h)): for i in range(len(self.h)):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata) hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -245,11 +240,11 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -245,11 +240,11 @@ class GPTBigCodeForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -16,15 +16,14 @@ ...@@ -16,15 +16,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import GPTJConfig from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
...@@ -93,14 +90,13 @@ class GPTJAttention(nn.Module): ...@@ -93,14 +90,13 @@ class GPTJAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output, _ = self.out_proj(attn_output) attn_output, _ = self.out_proj(attn_output)
return attn_output return attn_output
...@@ -154,8 +150,8 @@ class GPTJBlock(nn.Module): ...@@ -154,8 +150,8 @@ class GPTJBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
...@@ -163,7 +159,7 @@ class GPTJBlock(nn.Module): ...@@ -163,7 +159,7 @@ class GPTJBlock(nn.Module):
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
mlp_output = self.mlp(hidden_states) mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual hidden_states = attn_output + mlp_output + residual
...@@ -192,8 +188,8 @@ class GPTJModel(nn.Module): ...@@ -192,8 +188,8 @@ class GPTJModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) hidden_states = self.wte(input_ids)
for i in range(len(self.h)): for i in range(len(self.h)):
...@@ -202,7 +198,7 @@ class GPTJModel(nn.Module): ...@@ -202,7 +198,7 @@ class GPTJModel(nn.Module):
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
) )
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -232,11 +228,11 @@ class GPTJForCausalLM(nn.Module): ...@@ -232,11 +228,11 @@ class GPTJForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -16,15 +16,14 @@ ...@@ -16,15 +16,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import GPTNeoXConfig from transformers import GPTNeoXConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
...@@ -94,14 +91,13 @@ class GPTNeoXAttention(nn.Module): ...@@ -94,14 +91,13 @@ class GPTNeoXAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
...@@ -155,15 +151,15 @@ class GPTNeoXLayer(nn.Module): ...@@ -155,15 +151,15 @@ class GPTNeoXLayer(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states) attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention( attn_output = self.attention(
position_ids=position_ids, position_ids=position_ids,
hidden_states=attn_input, hidden_states=attn_input,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
if self.use_parallel_residual: if self.use_parallel_residual:
...@@ -208,8 +204,8 @@ class GPTNeoXModel(nn.Module): ...@@ -208,8 +204,8 @@ class GPTNeoXModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)): for i in range(len(self.layers)):
...@@ -218,7 +214,7 @@ class GPTNeoXModel(nn.Module): ...@@ -218,7 +214,7 @@ class GPTNeoXModel(nn.Module):
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
) )
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
return hidden_states return hidden_states
...@@ -246,11 +242,11 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -246,11 +242,11 @@ class GPTNeoXForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -5,9 +5,8 @@ import torch ...@@ -5,9 +5,8 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class InternLM2MLP(nn.Module): class InternLM2MLP(nn.Module):
...@@ -124,14 +121,13 @@ class InternLM2Attention(nn.Module): ...@@ -124,14 +121,13 @@ class InternLM2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states) qkv, _ = self.wqkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.wo(attn_output) output, _ = self.wo(attn_output)
return output return output
...@@ -172,8 +168,8 @@ class InternLMDecoderLayer(nn.Module): ...@@ -172,8 +168,8 @@ class InternLMDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -187,7 +183,7 @@ class InternLMDecoderLayer(nn.Module): ...@@ -187,7 +183,7 @@ class InternLMDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -221,8 +217,8 @@ class InternLM2Model(nn.Module): ...@@ -221,8 +217,8 @@ class InternLM2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.tok_embeddings(input_ids) hidden_states = self.tok_embeddings(input_ids)
residual = None residual = None
...@@ -232,7 +228,7 @@ class InternLM2Model(nn.Module): ...@@ -232,7 +228,7 @@ class InternLM2Model(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
...@@ -258,11 +254,11 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -258,11 +254,11 @@ class InternLM2ForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -20,14 +20,13 @@ ...@@ -20,14 +20,13 @@
"""Inference-only Jais model compatible with HuggingFace weights.""" """Inference-only Jais model compatible with HuggingFace weights."""
import math import math
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from vllm.transformers_utils.configs import JAISConfig from vllm.transformers_utils.configs import JAISConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
...@@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import ( ...@@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
KVCache = Tuple[torch.Tensor, torch.Tensor]
class SwiGLUActivation(nn.Module): class SwiGLUActivation(nn.Module):
...@@ -122,14 +119,12 @@ class JAISAttention(nn.Module): ...@@ -122,14 +119,12 @@ class JAISAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
...@@ -196,15 +191,15 @@ class JAISBlock(nn.Module): ...@@ -196,15 +191,15 @@ class JAISBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
...@@ -248,8 +243,8 @@ class JAISModel(nn.Module): ...@@ -248,8 +243,8 @@ class JAISModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.wte(input_ids) inputs_embeds = self.wte(input_ids)
if self.wpe is not None: if self.wpe is not None:
...@@ -262,7 +257,7 @@ class JAISModel(nn.Module): ...@@ -262,7 +257,7 @@ class JAISModel(nn.Module):
for i in range(len(self.h)): for i in range(len(self.h)):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata) hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -293,11 +288,11 @@ class JAISLMHeadModel(nn.Module): ...@@ -293,11 +288,11 @@ class JAISLMHeadModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
...@@ -348,4 +343,4 @@ class JAISLMHeadModel(nn.Module): ...@@ -348,4 +343,4 @@ class JAISLMHeadModel(nn.Module):
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
\ No newline at end of file
...@@ -27,10 +27,9 @@ import torch ...@@ -27,10 +27,9 @@ import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -48,8 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -48,8 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -150,14 +147,13 @@ class LlamaAttention(nn.Module): ...@@ -150,14 +147,13 @@ class LlamaAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -203,8 +199,8 @@ class LlamaDecoderLayer(nn.Module): ...@@ -203,8 +199,8 @@ class LlamaDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -218,7 +214,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -218,7 +214,7 @@ class LlamaDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -258,8 +254,8 @@ class LlamaModel(nn.Module): ...@@ -258,8 +254,8 @@ class LlamaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -269,7 +265,7 @@ class LlamaModel(nn.Module): ...@@ -269,7 +265,7 @@ class LlamaModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
...@@ -336,11 +332,11 @@ class LlamaForCausalLM(nn.Module): ...@@ -336,11 +332,11 @@ class LlamaForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -21,15 +21,14 @@ ...@@ -21,15 +21,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert """A tensor-parallel MoE implementation for Mixtral that shards each expert
...@@ -209,14 +206,13 @@ class MixtralAttention(nn.Module): ...@@ -209,14 +206,13 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -254,8 +250,8 @@ class MixtralDecoderLayer(nn.Module): ...@@ -254,8 +250,8 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -269,7 +265,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -269,7 +265,7 @@ class MixtralDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -309,15 +305,15 @@ class MixtralModel(nn.Module): ...@@ -309,15 +305,15 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata, kv_caches[i], attn_metadata,
residual) residual)
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -377,11 +373,11 @@ class MixtralForCausalLM(nn.Module): ...@@ -377,11 +373,11 @@ class MixtralForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import List, Optional, Tuple from typing import List, Optional
import numpy as np import numpy as np
...@@ -31,8 +31,7 @@ import torch.nn.functional as F ...@@ -31,8 +31,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear, ReplicatedLinear,
...@@ -52,8 +51,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -52,8 +51,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -227,14 +224,13 @@ class MixtralAttention(nn.Module): ...@@ -227,14 +224,13 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -269,8 +265,8 @@ class MixtralDecoderLayer(nn.Module): ...@@ -269,8 +265,8 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -284,7 +280,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -284,7 +280,7 @@ class MixtralDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -319,15 +315,15 @@ class MixtralModel(nn.Module): ...@@ -319,15 +315,15 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata, kv_caches[i], attn_metadata,
residual) residual)
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -352,11 +348,11 @@ class MixtralForCausalLM(nn.Module): ...@@ -352,11 +348,11 @@ class MixtralForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
# coding=utf-8 # coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes( def _get_alibi_slopes(
total_num_heads: int, total_num_heads: int,
...@@ -116,8 +113,8 @@ class MPTAttention(nn.Module): ...@@ -116,8 +113,8 @@ class MPTAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
del position_ids # unused. del position_ids # unused.
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.Wqkv(hidden_states)
...@@ -127,8 +124,7 @@ class MPTAttention(nn.Module): ...@@ -127,8 +124,7 @@ class MPTAttention(nn.Module):
if self.qk_ln: if self.qk_ln:
q = self.q_ln(q) q = self.q_ln(q)
k = self.k_ln(k) k = self.k_ln(k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
...@@ -184,15 +180,15 @@ class MPTBlock(nn.Module): ...@@ -184,15 +180,15 @@ class MPTBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
x = self.norm_1(hidden_states) x = self.norm_1(hidden_states)
x = self.attn( x = self.attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=x, hidden_states=x,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
hidden_states = hidden_states + x hidden_states = hidden_states + x
x = self.norm_2(hidden_states) x = self.norm_2(hidden_states)
...@@ -230,8 +226,8 @@ class MPTModel(nn.Module): ...@@ -230,8 +226,8 @@ class MPTModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)): for i in range(len(self.blocks)):
...@@ -240,7 +236,7 @@ class MPTModel(nn.Module): ...@@ -240,7 +236,7 @@ class MPTModel(nn.Module):
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
) )
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
return hidden_states return hidden_states
...@@ -267,11 +263,11 @@ class MPTForCausalLM(nn.Module): ...@@ -267,11 +263,11 @@ class MPTForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -42,8 +42,7 @@ import torch ...@@ -42,8 +42,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
...@@ -67,8 +66,6 @@ from vllm.sequence import SamplerOutput ...@@ -67,8 +66,6 @@ from vllm.sequence import SamplerOutput
# this model must need this dependency # this model must need this dependency
from hf_olmo import OLMoConfig from hf_olmo import OLMoConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class SwiGLU(nn.Module): class SwiGLU(nn.Module):
...@@ -146,16 +143,15 @@ class OlmoAttention(nn.Module): ...@@ -146,16 +143,15 @@ class OlmoAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.attn_norm(hidden_states) hidden_states = self.attn_norm(hidden_states)
qkv, _ = self.att_proj(hidden_states) qkv, _ = self.att_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.config.rope: if self.config.rope:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.attn_out(attn_output) output, _ = self.attn_out(attn_output)
return output return output
...@@ -241,12 +237,12 @@ class OlmoBlock(nn.Module): ...@@ -241,12 +237,12 @@ class OlmoBlock(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block. # Attention block.
og_x = hidden_states og_x = hidden_states
x = self.attn(positions, hidden_states, kv_cache, input_metadata) x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
x = x + og_x x = x + og_x
# MLP block. # MLP block.
...@@ -296,8 +292,8 @@ class OlmoModel(nn.Module): ...@@ -296,8 +292,8 @@ class OlmoModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
:param input_ids: A tensor of shape `(batch_size, seq_len)`. :param input_ids: A tensor of shape `(batch_size, seq_len)`.
...@@ -313,7 +309,7 @@ class OlmoModel(nn.Module): ...@@ -313,7 +309,7 @@ class OlmoModel(nn.Module):
positions, positions,
x, x,
kv_caches[block_idx], kv_caches[block_idx],
input_metadata, attn_metadata,
) )
# Apply final layer norm. # Apply final layer norm.
...@@ -344,14 +340,14 @@ class OLMoForCausalLM(nn.Module): ...@@ -344,14 +340,14 @@ class OLMoForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
return hidden_states return hidden_states
......
...@@ -17,15 +17,14 @@ ...@@ -17,15 +17,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.""" """Inference-only OPT model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import OPTConfig from transformers import OPTConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
...@@ -97,14 +94,12 @@ class OPTAttention(nn.Module): ...@@ -97,14 +94,12 @@ class OPTAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
...@@ -152,8 +147,8 @@ class OPTDecoderLayer(nn.Module): ...@@ -152,8 +147,8 @@ class OPTDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -162,7 +157,7 @@ class OPTDecoderLayer(nn.Module): ...@@ -162,7 +157,7 @@ class OPTDecoderLayer(nn.Module):
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states, hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata) attn_metadata=attn_metadata)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention # 350m applies layer norm AFTER attention
if not self.do_layer_norm_before: if not self.do_layer_norm_before:
...@@ -241,8 +236,8 @@ class OPTDecoder(nn.Module): ...@@ -241,8 +236,8 @@ class OPTDecoder(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions) pos_embeds = self.embed_positions(positions)
...@@ -252,7 +247,7 @@ class OPTDecoder(nn.Module): ...@@ -252,7 +247,7 @@ class OPTDecoder(nn.Module):
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata) hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
if self.final_layer_norm is not None: if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -275,10 +270,10 @@ class OPTModel(nn.Module): ...@@ -275,10 +270,10 @@ class OPTModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
return self.decoder(input_ids, positions, kv_caches, input_metadata) return self.decoder(input_ids, positions, kv_caches, attn_metadata)
class OPTForCausalLM(nn.Module): class OPTForCausalLM(nn.Module):
...@@ -300,11 +295,11 @@ class OPTForCausalLM(nn.Module): ...@@ -300,11 +295,11 @@ class OPTForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -10,9 +10,8 @@ import torch ...@@ -10,9 +10,8 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class OrionMLP(nn.Module): class OrionMLP(nn.Module):
...@@ -128,14 +125,13 @@ class OrionAttention(nn.Module): ...@@ -128,14 +125,13 @@ class OrionAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -178,8 +174,8 @@ class OrionDecoderLayer(nn.Module): ...@@ -178,8 +174,8 @@ class OrionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -189,7 +185,7 @@ class OrionDecoderLayer(nn.Module): ...@@ -189,7 +185,7 @@ class OrionDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -227,8 +223,8 @@ class OrionModel(nn.Module): ...@@ -227,8 +223,8 @@ class OrionModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -238,7 +234,7 @@ class OrionModel(nn.Module): ...@@ -238,7 +234,7 @@ class OrionModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
residual, residual,
) )
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
...@@ -264,11 +260,11 @@ class OrionForCausalLM(nn.Module): ...@@ -264,11 +260,11 @@ class OrionForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -35,15 +35,14 @@ ...@@ -35,15 +35,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -60,8 +59,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -60,8 +59,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class PhiAttention(nn.Module): class PhiAttention(nn.Module):
...@@ -115,14 +112,13 @@ class PhiAttention(nn.Module): ...@@ -115,14 +112,13 @@ class PhiAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
...@@ -172,8 +168,8 @@ class PhiLayer(nn.Module): ...@@ -172,8 +168,8 @@ class PhiLayer(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -181,7 +177,7 @@ class PhiLayer(nn.Module): ...@@ -181,7 +177,7 @@ class PhiLayer(nn.Module):
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
feed_forward_hidden_states = self.mlp(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_outputs + feed_forward_hidden_states + residual hidden_states = attn_outputs + feed_forward_hidden_states + residual
...@@ -209,8 +205,8 @@ class PhiModel(nn.Module): ...@@ -209,8 +205,8 @@ class PhiModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
for i in range(self.config.num_hidden_layers): for i in range(self.config.num_hidden_layers):
...@@ -219,7 +215,7 @@ class PhiModel(nn.Module): ...@@ -219,7 +215,7 @@ class PhiModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
) )
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
...@@ -248,11 +244,11 @@ class PhiForCausalLM(nn.Module): ...@@ -248,11 +244,11 @@ class PhiForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
......
...@@ -10,9 +10,8 @@ import torch ...@@ -10,9 +10,8 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -30,8 +29,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -30,8 +29,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
...@@ -111,15 +108,13 @@ class QWenAttention(nn.Module): ...@@ -111,15 +108,13 @@ class QWenAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.c_proj(attn_output) output, _ = self.c_proj(attn_output)
return output return output
...@@ -153,8 +148,8 @@ class QWenBlock(nn.Module): ...@@ -153,8 +148,8 @@ class QWenBlock(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -167,7 +162,7 @@ class QWenBlock(nn.Module): ...@@ -167,7 +162,7 @@ class QWenBlock(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -201,8 +196,8 @@ class QWenModel(nn.Module): ...@@ -201,8 +196,8 @@ class QWenModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) hidden_states = self.wte(input_ids)
residual = None residual = None
...@@ -212,7 +207,7 @@ class QWenModel(nn.Module): ...@@ -212,7 +207,7 @@ class QWenModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
residual, residual,
) )
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
...@@ -238,11 +233,11 @@ class QWenLMHeadModel(nn.Module): ...@@ -238,11 +233,11 @@ class QWenLMHeadModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -28,9 +28,8 @@ import torch ...@@ -28,9 +28,8 @@ import torch
from torch import nn from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
...@@ -147,14 +144,13 @@ class Qwen2Attention(nn.Module): ...@@ -147,14 +144,13 @@ class Qwen2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -197,8 +193,8 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -197,8 +193,8 @@ class Qwen2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -212,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -212,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -248,8 +244,8 @@ class Qwen2Model(nn.Module): ...@@ -248,8 +244,8 @@ class Qwen2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -259,7 +255,7 @@ class Qwen2Model(nn.Module): ...@@ -259,7 +255,7 @@ class Qwen2Model(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
...@@ -315,11 +311,11 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -315,11 +311,11 @@ class Qwen2ForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -25,9 +25,8 @@ import torch ...@@ -25,9 +25,8 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -44,8 +43,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -44,8 +43,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
...@@ -134,14 +131,13 @@ class StablelmAttention(nn.Module): ...@@ -134,14 +131,13 @@ class StablelmAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -166,8 +162,8 @@ class StablelmDecoderLayer(nn.Module): ...@@ -166,8 +162,8 @@ class StablelmDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -176,7 +172,7 @@ class StablelmDecoderLayer(nn.Module): ...@@ -176,7 +172,7 @@ class StablelmDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -211,8 +207,8 @@ class StableLMEpochModel(nn.Module): ...@@ -211,8 +207,8 @@ class StableLMEpochModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)): for i in range(len(self.layers)):
...@@ -221,7 +217,7 @@ class StableLMEpochModel(nn.Module): ...@@ -221,7 +217,7 @@ class StableLMEpochModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
) )
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
...@@ -246,11 +242,11 @@ class StablelmForCausalLM(nn.Module): ...@@ -246,11 +242,11 @@ class StablelmForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -18,15 +18,14 @@ ...@@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import Starcoder2Config from transformers import Starcoder2Config
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -43,8 +42,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -43,8 +42,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class Starcoder2Attention(nn.Module): class Starcoder2Attention(nn.Module):
...@@ -111,14 +108,13 @@ class Starcoder2Attention(nn.Module): ...@@ -111,14 +108,13 @@ class Starcoder2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -171,8 +167,8 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -171,8 +167,8 @@ class Starcoder2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -181,7 +177,7 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -181,7 +177,7 @@ class Starcoder2DecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -217,14 +213,14 @@ class Starcoder2Model(nn.Module): ...@@ -217,14 +213,14 @@ class Starcoder2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i], hidden_states = layer(positions, hidden_states, kv_caches[i],
input_metadata) attn_metadata)
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
...@@ -258,11 +254,11 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -258,11 +254,11 @@ class Starcoder2ForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment