Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc1fa12
Unverified
Commit
cdc1fa12
authored
Feb 25, 2025
by
Harry Mellor
Committed by
GitHub
Feb 24, 2025
Browse files
Remove unused kwargs from model definitions (#13555)
parent
f61528d4
Changes
104
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
103 additions
and
360 deletions
+103
-360
docs/source/contributing/model/basic.md
docs/source/contributing/model/basic.md
+0
-2
docs/source/contributing/model/multimodal.md
docs/source/contributing/model/multimodal.md
+0
-2
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+3
-11
vllm/attention/layer.py
vllm/attention/layer.py
+7
-12
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+3
-2
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+2
-2
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+1
-5
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+5
-19
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+0
-5
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+5
-19
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+7
-22
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+20
-73
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+12
-32
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+1
-6
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+7
-24
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+4
-23
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+8
-34
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+5
-19
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+7
-28
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+6
-20
No files found.
docs/source/contributing/model/basic.md
View file @
cdc1fa12
...
...
@@ -74,8 +74,6 @@ def forward(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
...
```
...
...
docs/source/contributing/model/multimodal.md
View file @
cdc1fa12
...
...
@@ -16,8 +16,6 @@ Further update the model as follows:
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
```
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
cdc1fa12
...
...
@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
def
_run_decoder_self_attention_test
(
...
...
@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
& attn_metadata
'''
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
,
vllm_config
):
...
...
@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
kv_cache
,
attn_metadata
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
def
_run_encoder_decoder_cross_attention_test
(
...
...
@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
if
cross_test_params
is
None
:
key
=
None
value
=
None
...
...
@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
attn_metadata
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
)
@
pytest
.
fixture
(
autouse
=
True
)
...
...
vllm/attention/layer.py
View file @
cdc1fa12
...
...
@@ -7,7 +7,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
...
...
@@ -153,15 +153,10 @@ class Attention(nn.Module):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
# directly, use `self.kv_cache` and
# `get_forward_context().attn_metadata` instead.
if
self
.
calculate_kv_scales
:
ctx_
attn_metadata
=
get_forward_context
().
attn_metadata
if
ctx_
attn_metadata
.
enable_kv_scales_calculation
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
enable_kv_scales_calculation
:
self
.
calc_kv_scales
(
key
,
value
)
if
self
.
use_output
:
output
=
torch
.
empty_like
(
query
)
...
...
@@ -177,14 +172,14 @@ class Attention(nn.Module):
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
ctx_
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
self_kv_cache
,
ctx_
attn_metadata
,
attn_metadata
,
output
=
output
)
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
...
...
@@ -193,10 +188,10 @@ class Attention(nn.Module):
else
:
if
self
.
use_direct_call
:
forward_context
=
get_forward_context
()
ctx_
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
self_kv_cache
,
ctx_
attn_metadata
)
self_kv_cache
,
attn_metadata
)
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
cdc1fa12
...
...
@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -130,14 +131,14 @@ class MambaMixer(CustomOp):
)
if
use_rms_norm
else
None
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
pass
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=-
2
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
cdc1fa12
...
...
@@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
...
...
@@ -376,17 +377,16 @@ class MambaMixer2(CustomOp):
eps
=
rms_norm_eps
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
pass
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
):
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
seq_len
,
_
=
hidden_states
.
shape
groups_time_state_size
=
self
.
n_groups
*
self
.
ssm_state_size
...
...
vllm/model_executor/models/adapters.py
View file @
cdc1fa12
...
...
@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
return
cls
# Lazy import
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
PoolingType
...
...
@@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T:
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
super
().
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
hidden_states
=
super
().
forward
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
logits
,
_
=
self
.
score
(
hidden_states
)
...
...
vllm/model_executor/models/arctic.py
View file @
cdc1fa12
...
...
@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -283,13 +283,11 @@ class ArcticAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual_input
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual_input
+
hidden_states
...
...
@@ -400,8 +394,6 @@ class ArcticModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -413,11 +405,8 @@ class ArcticModel(nn.Module):
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
positions
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm
(
hidden_states
)
...
...
@@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/aria.py
View file @
cdc1fa12
...
...
@@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from
transformers.models.aria.modeling_aria
import
AriaCrossAttention
from
transformers.models.aria.processing_aria
import
AriaProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
QuantizationConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.model_executor.layers.activation
import
get_act_fn
...
...
@@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states
=
self
.
language_model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
...
...
vllm/model_executor/models/baichuan.py
View file @
cdc1fa12
...
...
@@ -20,13 +20,13 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/bamba.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Bamba model."""
# Added by the IBM Team, 2024
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
BambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
mamba_cache_params
,
sequence_idx
)
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
,
sequence_idx
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
...
...
@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
**
kwargs
,
):
...
...
@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attention
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
...
...
@@ -312,8 +305,6 @@ class BambaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -323,6 +314,7 @@ class BambaModel(nn.Module):
# proper continuous batching computation including
# chunked prefill
seq_idx
=
None
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefills
>
0
:
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
for
i
,
(
srt
,
end
)
in
enumerate
(
...
...
@@ -348,9 +340,7 @@ class BambaModel(nn.Module):
num_attn
=
0
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
kv_cache
=
None
if
isinstance
(
layer
,
BambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[
num_attn
]
num_attn
+=
1
layer_mamba_cache_params
=
None
...
...
@@ -361,8 +351,6 @@ class BambaModel(nn.Module):
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
sequence_idx
=
seq_idx
,
...
...
@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
...
...
@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self
.
vllm_config
,
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
*
self
.
_get_mamba_cache_shape
())
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
mamba_cache_params
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/bart.py
View file @
cdc1fa12
...
...
@@ -19,14 +19,14 @@
# limitations under the License.
"""PyTorch BART model."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
BartConfig
from
transformers.utils
import
logging
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
...
...
@@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
ENCODER
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
DECODER
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module):
def
forward
(
self
,
decoder_hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
...
...
@@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module):
_
,
k
,
v
=
qkv_enc
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module):
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
kv_cache:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Encoder layer output torch.Tensor
"""
residual
=
hidden_states
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
...
...
@@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module):
def
forward
(
self
,
decoder_hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
r
"""
Args:
decoder_hidden_states
torch.Tensor of *decoder* input embeddings.
kv_cache:
KV cache tensor
attn_metadata:
vLLM Attention metadata structure
encoder_hidden_states
torch.Tensor of *encoder* input embeddings.
Returns:
...
...
@@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module):
residual
=
decoder_hidden_states
# Self Attention
hidden_states
=
self
.
self_attn
(
hidden_states
=
decoder_hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
decoder_hidden_states
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
...
...
@@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module):
hidden_states
=
self
.
encoder_attn
(
decoder_hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
encoder_hidden_states
=
encoder_hidden_states
,
)
...
...
@@ -609,9 +588,8 @@ class BartEncoder(nn.Module):
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
embed_dim
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""
Args:
input_ids
...
...
@@ -620,10 +598,6 @@ class BartEncoder(nn.Module):
provide it.
positions
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
...
...
@@ -636,12 +610,8 @@ class BartEncoder(nn.Module):
hidden_states
=
inputs_embeds
+
embed_pos
hidden_states
=
self
.
layernorm_embedding
(
hidden_states
)
for
idx
,
encoder_layer
in
enumerate
(
self
.
layers
):
hidden_states
=
encoder_layer
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
)
for
encoder_layer
in
self
.
layers
:
hidden_states
=
encoder_layer
(
hidden_states
=
hidden_states
)
return
hidden_states
...
...
@@ -693,9 +663,7 @@ class BartDecoder(nn.Module):
def
forward
(
self
,
decoder_input_ids
:
torch
.
Tensor
,
decoder_positions
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
encoder_hidden_states
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
r
"""
Args:
decoder_input_ids
...
...
@@ -706,10 +674,6 @@ class BartDecoder(nn.Module):
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
...
...
@@ -725,11 +689,9 @@ class BartDecoder(nn.Module):
# decoder layers
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
)
:
for
decoder_layer
in
self
.
layers
:
hidden_states
=
decoder_layer
(
decoder_hidden_states
=
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
encoder_hidden_states
=
encoder_hidden_states
,
)
...
...
@@ -768,8 +730,7 @@ class BartModel(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
encoder_positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""
Args:
input_ids
...
...
@@ -782,10 +743,6 @@ class BartModel(nn.Module):
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Model output torch.Tensor
"""
...
...
@@ -796,18 +753,14 @@ class BartModel(nn.Module):
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states
=
self
.
encoder
(
input_ids
=
encoder_input_ids
,
positions
=
encoder_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
)
positions
=
encoder_positions
)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs
=
self
.
decoder
(
decoder_input_ids
=
input_ids
,
decoder_positions
=
positions
,
encoder_hidden_states
=
encoder_hidden_states
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
)
encoder_hidden_states
=
encoder_hidden_states
)
return
decoder_outputs
...
...
@@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
*
,
encoder_input_ids
:
torch
.
Tensor
,
...
...
@@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module):
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Output torch.Tensor
"""
return
self
.
model
(
input_ids
,
positions
,
encoder_input_ids
,
encoder_positions
,
kv_caches
,
attn_metadata
)
encoder_positions
)
def
compute_logits
(
self
,
...
...
vllm/model_executor/models/bert.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
BertConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
for
i
in
range
(
len
(
self
.
layer
)):
layer
=
self
.
layer
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn_metadata
)
for
layer
in
self
.
layer
:
hidden_states
=
layer
(
hidden_states
)
return
hidden_states
...
...
@@ -152,13 +150,8 @@ class BertLayer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.output"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
):
attn_output
=
self
.
attention
(
hidden_states
,
kv_cache
,
attn_metadata
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
):
attn_output
=
self
.
attention
(
hidden_states
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
return
output
...
...
@@ -191,10 +184,8 @@ class BertAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
self_output
=
self
.
self
(
hidden_states
,
kv_cache
,
attn_metadata
)
self_output
=
self
.
self
(
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
...
...
@@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
=
self
.
attn
(
q
,
k
,
v
)
return
output
...
...
@@ -343,8 +332,6 @@ class BertModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -352,13 +339,14 @@ class BertModel(nn.Module):
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
attn_metadata
=
get_forward_context
().
attn_metadata
assert
hasattr
(
attn_metadata
,
"seq_lens_tensor"
)
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
return
self
.
encoder
(
hidden_states
,
kv_caches
,
attn_metadata
)
return
self
.
encoder
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
...
...
@@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
=
input_ids
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
attn_metadata
=
attn_metadata
)
intermediate_tensors
=
intermediate_tensors
)
def
pooler
(
self
,
...
...
@@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
bert
(
input_ids
=
input_ids
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
attn_metadata
=
attn_metadata
,
token_type_ids
=
token_type_ids
)
vllm/model_executor/models/blip2.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
...
...
@@ -9,7 +9,6 @@ import torch.nn as nn
from
transformers
import
(
BatchFeature
,
Blip2Config
,
Blip2QFormerConfig
,
apply_chunking_to_forward
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/bloom.py
View file @
cdc1fa12
...
...
@@ -18,13 +18,13 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
BloomConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
del
position_ids
# Unused.
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Layer norm at the beginning of the transformer layer.
layernorm_output
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
attention_output
=
self
.
self_attention
(
position_ids
=
position_ids
,
hidden_states
=
layernorm_output
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
attention_output
=
attention_output
+
residual
layernorm_output
=
self
.
post_attention_layernorm
(
attention_output
)
...
...
@@ -266,8 +260,6 @@ class BloomModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -279,14 +271,8 @@ class BloomModel(nn.Module):
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
position_ids
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
for
layer
in
self
.
h
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
position_ids
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
...
...
@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/chameleon.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
from
functools
import
cached_property
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
...
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from
transformers
import
(
BatchFeature
,
ChameleonConfig
,
ChameleonProcessor
,
ChameleonVQVAEConfig
)
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
...
...
@@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -906,8 +896,6 @@ class ChameleonModel(nn.Module):
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -921,13 +909,10 @@ class ChameleonModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
...
...
@@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/chatglm.py
View file @
cdc1fa12
...
...
@@ -2,13 +2,13 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -108,19 +108,11 @@ class GLMAttention(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
context_layer
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
)
context_layer
=
self
.
attn
(
q
,
k
,
v
)
attn_output
,
_
=
self
.
dense
(
context_layer
)
return
attn_output
...
...
@@ -215,8 +207,6 @@ class GLMBlock(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
...
...
@@ -225,8 +215,6 @@ class GLMBlock(nn.Module):
attention_output
=
self
.
self_attention
(
hidden_states
=
layernorm_output
,
position_ids
=
position_ids
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Residual connection.
...
...
@@ -289,17 +277,10 @@ class GLMTransformer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
hidden_states
=
hidden_states
,
position_ids
=
position_ids
,
kv_cache
=
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
=
attn_metadata
,
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
hidden_states
=
hidden_states
,
position_ids
=
position_ids
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
...
@@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module):
hidden_states
=
self
.
encoder
(
hidden_states
=
hidden_states
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
return
hidden_states
...
...
@@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
vllm/model_executor/models/commandr.py
View file @
cdc1fa12
...
...
@@ -21,14 +21,14 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
...
@@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
if
self
.
v1
or
self
.
sliding_window
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states_mlp
=
self
.
mlp
(
hidden_states
)
# Add everything together
...
...
@@ -311,8 +305,6 @@ class CohereModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -326,13 +318,10 @@ class CohereModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/dbrx.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
...
@@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
if
self
.
clip_qkv
is
not
None
:
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
hidden_states
,
_
=
self
.
out_proj
(
attn_output
)
return
hidden_states
...
...
@@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
norm_1
(
hidden_states
)
x
=
self
.
attn
(
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual
+
x
residual
=
hidden_states
...
...
@@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
,
residual
=
self
.
norm_attn_norm
(
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
self
.
ffn
(
hidden_states
)
hidden_states
=
hidden_states
+
residual
...
...
@@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
else
:
assert
intermediate_tensors
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
block
=
self
.
blocks
[
i
]
hidden_states
=
block
(
position_ids
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
for
block
in
self
.
blocks
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
block
(
position_ids
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm_f
(
hidden_states
)
...
...
@@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/deepseek.py
View file @
cdc1fa12
...
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
...
...
@@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -370,8 +364,6 @@ class DeepseekModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -384,11 +376,8 @@ class DeepseekModel(nn.Module):
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment