Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc1fa12
Unverified
Commit
cdc1fa12
authored
Feb 25, 2025
by
Harry Mellor
Committed by
GitHub
Feb 24, 2025
Browse files
Remove unused kwargs from model definitions (#13555)
parent
f61528d4
Changes
104
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
103 additions
and
360 deletions
+103
-360
docs/source/contributing/model/basic.md
docs/source/contributing/model/basic.md
+0
-2
docs/source/contributing/model/multimodal.md
docs/source/contributing/model/multimodal.md
+0
-2
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+3
-11
vllm/attention/layer.py
vllm/attention/layer.py
+7
-12
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+3
-2
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+2
-2
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+1
-5
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+5
-19
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+0
-5
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+5
-19
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+7
-22
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+20
-73
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+12
-32
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+1
-6
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+7
-24
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+4
-23
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+8
-34
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+5
-19
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+7
-28
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+6
-20
No files found.
docs/source/contributing/model/basic.md
View file @
cdc1fa12
...
@@ -74,8 +74,6 @@ def forward(
...
@@ -74,8 +74,6 @@ def forward(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
```
```
...
...
docs/source/contributing/model/multimodal.md
View file @
cdc1fa12
...
@@ -16,8 +16,6 @@ Further update the model as follows:
...
@@ -16,8 +16,6 @@ Further update the model as follows:
self,
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor,
positions: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
) -> SamplerOutput:
```
```
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
cdc1fa12
...
@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
...
@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
)
def
_run_decoder_self_attention_test
(
def
_run_decoder_self_attention_test
(
...
@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
...
@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
& attn_metadata
& attn_metadata
'''
'''
attn
=
test_rsrcs
.
attn
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
,
vllm_config
):
with
set_forward_context
(
attn_metadata
,
vllm_config
):
...
@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
...
@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
kv_cache
,
attn_metadata
)
def
_run_encoder_decoder_cross_attention_test
(
def
_run_encoder_decoder_cross_attention_test
(
...
@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
attn
=
test_rsrcs
.
attn
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
if
cross_test_params
is
None
:
if
cross_test_params
is
None
:
key
=
None
key
=
None
value
=
None
value
=
None
...
@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
return
attn
.
forward
(
reshaped_query
,
key
,
value
)
attn_metadata
)
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
...
vllm/attention/layer.py
View file @
cdc1fa12
...
@@ -7,7 +7,7 @@ import torch.nn as nn
...
@@ -7,7 +7,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
...
@@ -153,15 +153,10 @@ class Attention(nn.Module):
...
@@ -153,15 +153,10 @@ class Attention(nn.Module):
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
# directly, use `self.kv_cache` and
# `get_forward_context().attn_metadata` instead.
if
self
.
calculate_kv_scales
:
if
self
.
calculate_kv_scales
:
ctx_
attn_metadata
=
get_forward_context
().
attn_metadata
attn_metadata
=
get_forward_context
().
attn_metadata
if
ctx_
attn_metadata
.
enable_kv_scales_calculation
:
if
attn_metadata
.
enable_kv_scales_calculation
:
self
.
calc_kv_scales
(
key
,
value
)
self
.
calc_kv_scales
(
key
,
value
)
if
self
.
use_output
:
if
self
.
use_output
:
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
...
@@ -177,14 +172,14 @@ class Attention(nn.Module):
...
@@ -177,14 +172,14 @@ class Attention(nn.Module):
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
ctx_
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
self
,
self
.
impl
.
forward
(
self
,
query
,
query
,
key
,
key
,
value
,
value
,
self_kv_cache
,
self_kv_cache
,
ctx_
attn_metadata
,
attn_metadata
,
output
=
output
)
output
=
output
)
else
:
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
torch
.
ops
.
vllm
.
unified_attention_with_output
(
...
@@ -193,10 +188,10 @@ class Attention(nn.Module):
...
@@ -193,10 +188,10 @@ class Attention(nn.Module):
else
:
else
:
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
ctx_
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
self_kv_cache
,
ctx_
attn_metadata
)
self_kv_cache
,
attn_metadata
)
else
:
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
query
,
key
,
value
,
self
.
layer_name
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
cdc1fa12
...
@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
...
@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -130,14 +131,14 @@ class MambaMixer(CustomOp):
...
@@ -130,14 +131,14 @@ class MambaMixer(CustomOp):
)
if
use_rms_norm
else
None
)
if
use_rms_norm
else
None
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
pass
pass
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
mamba_cache_params
:
MambaCacheParams
):
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=-
2
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=-
2
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
cdc1fa12
...
@@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -376,17 +377,16 @@ class MambaMixer2(CustomOp):
...
@@ -376,17 +377,16 @@ class MambaMixer2(CustomOp):
eps
=
rms_norm_eps
)
eps
=
rms_norm_eps
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
pass
pass
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
seq_len
,
_
=
hidden_states
.
shape
seq_len
,
_
=
hidden_states
.
shape
groups_time_state_size
=
self
.
n_groups
*
self
.
ssm_state_size
groups_time_state_size
=
self
.
n_groups
*
self
.
ssm_state_size
...
...
vllm/model_executor/models/adapters.py
View file @
cdc1fa12
...
@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
...
@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
return
cls
return
cls
# Lazy import
# Lazy import
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.model_executor.layers.pooler
import
PoolingType
...
@@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T:
...
@@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T:
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
super
().
forward
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
super
().
forward
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
logits
,
_
=
self
.
score
(
hidden_states
)
logits
,
_
=
self
.
score
(
hidden_states
)
...
...
vllm/model_executor/models/arctic.py
View file @
cdc1fa12
...
@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
...
@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -283,13 +283,11 @@ class ArcticAttention(nn.Module):
...
@@ -283,13 +283,11 @@ class ArcticAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module):
...
@@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual_input
=
hidden_states
residual_input
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual_input
+
hidden_states
hidden_states
=
residual_input
+
hidden_states
...
@@ -400,8 +394,6 @@ class ArcticModel(nn.Module):
...
@@ -400,8 +394,6 @@ class ArcticModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -413,11 +405,8 @@ class ArcticModel(nn.Module):
...
@@ -413,11 +405,8 @@ class ArcticModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
...
@@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
...
@@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/aria.py
View file @
cdc1fa12
...
@@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
...
@@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from
transformers.models.aria.modeling_aria
import
AriaCrossAttention
from
transformers.models.aria.modeling_aria
import
AriaCrossAttention
from
transformers.models.aria.processing_aria
import
AriaProcessor
from
transformers.models.aria.processing_aria
import
AriaProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
QuantizationConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
QuantizationConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
...
@@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states
=
self
.
language_model
(
hidden_states
=
self
.
language_model
(
input_ids
,
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
...
...
vllm/model_executor/models/baichuan.py
View file @
cdc1fa12
...
@@ -20,13 +20,13 @@
...
@@ -20,13 +20,13 @@
# limitations under the License.
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module):
...
@@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
if
self
.
postion_embedding
!=
"ALIBI"
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module):
...
@@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module):
...
@@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module):
...
@@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module):
...
@@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
...
@@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/bamba.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Bamba model."""
"""Inference-only Bamba model."""
# Added by the IBM Team, 2024
# Added by the IBM Team, 2024
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BambaConfig
from
transformers
import
BambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
,
mamba_cache_params
,
sequence_idx
)
sequence_idx
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
...
@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
...
@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
...
@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
**
kwargs
,
**
kwargs
,
):
):
...
@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
...
@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attention
(
hidden_states
=
self
.
self_attention
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
...
@@ -312,8 +305,6 @@ class BambaModel(nn.Module):
...
@@ -312,8 +305,6 @@ class BambaModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -323,6 +314,7 @@ class BambaModel(nn.Module):
...
@@ -323,6 +314,7 @@ class BambaModel(nn.Module):
# proper continuous batching computation including
# proper continuous batching computation including
# chunked prefill
# chunked prefill
seq_idx
=
None
seq_idx
=
None
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefills
>
0
:
if
attn_metadata
.
num_prefills
>
0
:
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
for
i
,
(
srt
,
end
)
in
enumerate
(
for
i
,
(
srt
,
end
)
in
enumerate
(
...
@@ -348,9 +340,7 @@ class BambaModel(nn.Module):
...
@@ -348,9 +340,7 @@ class BambaModel(nn.Module):
num_attn
=
0
num_attn
=
0
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
kv_cache
=
None
if
isinstance
(
layer
,
BambaAttentionDecoderLayer
):
if
isinstance
(
layer
,
BambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[
num_attn
]
num_attn
+=
1
num_attn
+=
1
layer_mamba_cache_params
=
None
layer_mamba_cache_params
=
None
...
@@ -361,8 +351,6 @@ class BambaModel(nn.Module):
...
@@ -361,8 +351,6 @@ class BambaModel(nn.Module):
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba_cache_params
=
layer_mamba_cache_params
,
sequence_idx
=
seq_idx
,
sequence_idx
=
seq_idx
,
...
@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
...
@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self
.
vllm_config
,
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
self
.
vllm_config
,
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
*
self
.
_get_mamba_cache_shape
())
*
self
.
_get_mamba_cache_shape
())
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
attn_metadata
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/bart.py
View file @
cdc1fa12
...
@@ -19,14 +19,14 @@
...
@@ -19,14 +19,14 @@
# limitations under the License.
# limitations under the License.
"""PyTorch BART model."""
"""PyTorch BART model."""
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BartConfig
from
transformers
import
BartConfig
from
transformers.utils
import
logging
from
transformers.utils
import
logging
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
...
@@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module):
...
@@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
ENCODER
)
attn_type
=
AttentionType
.
ENCODER
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module):
...
@@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
DECODER
)
attn_type
=
AttentionType
.
DECODER
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module):
...
@@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
decoder_hidden_states
:
torch
.
Tensor
,
decoder_hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
...
@@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module):
...
@@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module):
_
,
k
,
v
=
qkv_enc
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
_
,
k
,
v
=
qkv_enc
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module):
...
@@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module):
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
hidden_states
hidden_states
torch.Tensor of *encoder* input embeddings.
torch.Tensor of *encoder* input embeddings.
kv_cache:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Returns:
Encoder layer output torch.Tensor
Encoder layer output torch.Tensor
"""
"""
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
...
@@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module):
...
@@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
decoder_hidden_states
:
torch
.
Tensor
,
decoder_hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
decoder_hidden_states
decoder_hidden_states
torch.Tensor of *decoder* input embeddings.
torch.Tensor of *decoder* input embeddings.
kv_cache:
KV cache tensor
attn_metadata:
vLLM Attention metadata structure
encoder_hidden_states
encoder_hidden_states
torch.Tensor of *encoder* input embeddings.
torch.Tensor of *encoder* input embeddings.
Returns:
Returns:
...
@@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module):
...
@@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module):
residual
=
decoder_hidden_states
residual
=
decoder_hidden_states
# Self Attention
# Self Attention
hidden_states
=
self
.
self_attn
(
hidden_states
=
decoder_hidden_states
,
hidden_states
=
self
.
self_attn
(
hidden_states
=
decoder_hidden_states
)
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
...
@@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module):
...
@@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module):
hidden_states
=
self
.
encoder_attn
(
hidden_states
=
self
.
encoder_attn
(
decoder_hidden_states
=
hidden_states
,
decoder_hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
)
)
...
@@ -609,9 +588,8 @@ class BartEncoder(nn.Module):
...
@@ -609,9 +588,8 @@ class BartEncoder(nn.Module):
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
embed_dim
)
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
embed_dim
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
input_ids
input_ids
...
@@ -620,10 +598,6 @@ class BartEncoder(nn.Module):
...
@@ -620,10 +598,6 @@ class BartEncoder(nn.Module):
provide it.
provide it.
positions
positions
Positions of *encoder* input sequence tokens.
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Returns:
Decoder output torch.Tensor
Decoder output torch.Tensor
"""
"""
...
@@ -636,12 +610,8 @@ class BartEncoder(nn.Module):
...
@@ -636,12 +610,8 @@ class BartEncoder(nn.Module):
hidden_states
=
inputs_embeds
+
embed_pos
hidden_states
=
inputs_embeds
+
embed_pos
hidden_states
=
self
.
layernorm_embedding
(
hidden_states
)
hidden_states
=
self
.
layernorm_embedding
(
hidden_states
)
for
idx
,
encoder_layer
in
enumerate
(
self
.
layers
):
for
encoder_layer
in
self
.
layers
:
hidden_states
=
encoder_layer
(
hidden_states
=
encoder_layer
(
hidden_states
=
hidden_states
)
hidden_states
=
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
)
return
hidden_states
return
hidden_states
...
@@ -693,9 +663,7 @@ class BartDecoder(nn.Module):
...
@@ -693,9 +663,7 @@ class BartDecoder(nn.Module):
def
forward
(
self
,
decoder_input_ids
:
torch
.
Tensor
,
def
forward
(
self
,
decoder_input_ids
:
torch
.
Tensor
,
decoder_positions
:
torch
.
Tensor
,
decoder_positions
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
encoder_hidden_states
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
decoder_input_ids
decoder_input_ids
...
@@ -706,10 +674,6 @@ class BartDecoder(nn.Module):
...
@@ -706,10 +674,6 @@ class BartDecoder(nn.Module):
Positions of *decoder* input sequence tokens.
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
encoder_hidden_states:
Tensor of encoder output embeddings
Tensor of encoder output embeddings
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Returns:
Decoder output torch.Tensor
Decoder output torch.Tensor
"""
"""
...
@@ -725,11 +689,9 @@ class BartDecoder(nn.Module):
...
@@ -725,11 +689,9 @@ class BartDecoder(nn.Module):
# decoder layers
# decoder layers
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
)
:
for
decoder_layer
in
self
.
layers
:
hidden_states
=
decoder_layer
(
hidden_states
=
decoder_layer
(
decoder_hidden_states
=
hidden_states
,
decoder_hidden_states
=
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
)
)
...
@@ -768,8 +730,7 @@ class BartModel(nn.Module):
...
@@ -768,8 +730,7 @@ class BartModel(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
encoder_positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
input_ids
input_ids
...
@@ -782,10 +743,6 @@ class BartModel(nn.Module):
...
@@ -782,10 +743,6 @@ class BartModel(nn.Module):
Indices of *encoder* input sequence tokens in the vocabulary.
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
encoder_positions:
Positions of *encoder* input sequence tokens.
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Returns:
Model output torch.Tensor
Model output torch.Tensor
"""
"""
...
@@ -796,18 +753,14 @@ class BartModel(nn.Module):
...
@@ -796,18 +753,14 @@ class BartModel(nn.Module):
# Run encoder attention if a non-zero number of encoder tokens
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
# are provided as input
encoder_hidden_states
=
self
.
encoder
(
input_ids
=
encoder_input_ids
,
encoder_hidden_states
=
self
.
encoder
(
input_ids
=
encoder_input_ids
,
positions
=
encoder_positions
,
positions
=
encoder_positions
)
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
)
# decoder outputs consists of
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs
=
self
.
decoder
(
decoder_outputs
=
self
.
decoder
(
decoder_input_ids
=
input_ids
,
decoder_input_ids
=
input_ids
,
decoder_positions
=
positions
,
decoder_positions
=
positions
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
)
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
)
return
decoder_outputs
return
decoder_outputs
...
@@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module):
...
@@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
*
,
*
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
...
@@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module):
...
@@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module):
torch.Tensor of *encoder* input token ids.
torch.Tensor of *encoder* input token ids.
encoder_positions
encoder_positions
torch.Tensor of *encoder* position indices
torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Returns:
Output torch.Tensor
Output torch.Tensor
"""
"""
return
self
.
model
(
input_ids
,
positions
,
encoder_input_ids
,
return
self
.
model
(
input_ids
,
positions
,
encoder_input_ids
,
encoder_positions
,
kv_caches
,
attn_metadata
)
encoder_positions
)
def
compute_logits
(
def
compute_logits
(
self
,
self
,
...
...
vllm/model_executor/models/bert.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BertConfig
from
transformers
import
BertConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
...
@@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
for
i
in
range
(
len
(
self
.
layer
)):
for
layer
in
self
.
layer
:
layer
=
self
.
layer
[
i
]
hidden_states
=
layer
(
hidden_states
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn_metadata
)
return
hidden_states
return
hidden_states
...
@@ -152,13 +150,8 @@ class BertLayer(nn.Module):
...
@@ -152,13 +150,8 @@ class BertLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.output"
)
prefix
=
f
"
{
prefix
}
.output"
)
def
forward
(
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
):
self
,
attn_output
=
self
.
attention
(
hidden_states
)
hidden_states
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
):
attn_output
=
self
.
attention
(
hidden_states
,
kv_cache
,
attn_metadata
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
return
output
return
output
...
@@ -191,10 +184,8 @@ class BertAttention(nn.Module):
...
@@ -191,10 +184,8 @@ class BertAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
self_output
=
self
.
self
(
hidden_states
,
kv_cache
,
attn_metadata
)
self_output
=
self
.
self
(
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
...
@@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
...
@@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
=
self
.
attn
(
q
,
k
,
v
)
return
output
return
output
...
@@ -343,8 +332,6 @@ class BertModel(nn.Module):
...
@@ -343,8 +332,6 @@ class BertModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -352,13 +339,14 @@ class BertModel(nn.Module):
...
@@ -352,13 +339,14 @@ class BertModel(nn.Module):
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
else
:
else
:
attn_metadata
=
get_forward_context
().
attn_metadata
assert
hasattr
(
attn_metadata
,
"seq_lens_tensor"
)
assert
hasattr
(
attn_metadata
,
"seq_lens_tensor"
)
hidden_states
=
self
.
embeddings
(
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
return
self
.
encoder
(
hidden_states
,
kv_caches
,
attn_metadata
)
return
self
.
encoder
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
...
@@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
...
@@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
=
input_ids
,
return
self
.
model
(
input_ids
=
input_ids
,
position_ids
=
positions
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
)
attn_metadata
=
attn_metadata
)
def
pooler
(
def
pooler
(
self
,
self
,
...
@@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
bert
(
input_ids
=
input_ids
,
return
self
.
bert
(
input_ids
=
input_ids
,
position_ids
=
positions
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
attn_metadata
=
attn_metadata
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
vllm/model_executor/models/blip2.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
import
torch
import
torch
...
@@ -9,7 +9,6 @@ import torch.nn as nn
...
@@ -9,7 +9,6 @@ import torch.nn as nn
from
transformers
import
(
BatchFeature
,
Blip2Config
,
Blip2QFormerConfig
,
from
transformers
import
(
BatchFeature
,
Blip2Config
,
Blip2QFormerConfig
,
apply_chunking_to_forward
)
apply_chunking_to_forward
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
@@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/bloom.py
View file @
cdc1fa12
...
@@ -18,13 +18,13 @@
...
@@ -18,13 +18,13 @@
# limitations under the License.
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BloomConfig
from
transformers
import
BloomConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
...
@@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
del
position_ids
# Unused.
del
position_ids
# Unused.
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
return
output
...
@@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
...
@@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Layer norm at the beginning of the transformer layer.
# Layer norm at the beginning of the transformer layer.
layernorm_output
=
self
.
input_layernorm
(
hidden_states
)
layernorm_output
=
self
.
input_layernorm
(
hidden_states
)
...
@@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
...
@@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
attention_output
=
self
.
self_attention
(
attention_output
=
self
.
self_attention
(
position_ids
=
position_ids
,
position_ids
=
position_ids
,
hidden_states
=
layernorm_output
,
hidden_states
=
layernorm_output
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
attention_output
=
attention_output
+
residual
attention_output
=
attention_output
+
residual
layernorm_output
=
self
.
post_attention_layernorm
(
attention_output
)
layernorm_output
=
self
.
post_attention_layernorm
(
attention_output
)
...
@@ -266,8 +260,6 @@ class BloomModel(nn.Module):
...
@@ -266,8 +260,6 @@ class BloomModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -279,14 +271,8 @@ class BloomModel(nn.Module):
...
@@ -279,14 +271,8 @@ class BloomModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
h
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
position_ids
,
hidden_states
)
hidden_states
=
layer
(
position_ids
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
...
@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
...
@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/chameleon.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from
transformers
import
(
BatchFeature
,
ChameleonConfig
,
ChameleonProcessor
,
from
transformers
import
(
BatchFeature
,
ChameleonConfig
,
ChameleonProcessor
,
ChameleonVQVAEConfig
)
ChameleonVQVAEConfig
)
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module):
...
@@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module):
...
@@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module):
...
@@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
...
@@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
...
@@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
@@ -906,8 +896,6 @@ class ChameleonModel(nn.Module):
...
@@ -906,8 +896,6 @@ class ChameleonModel(nn.Module):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -921,13 +909,10 @@ class ChameleonModel(nn.Module):
...
@@ -921,13 +909,10 @@ class ChameleonModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
**
kwargs
,
...
@@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
model
(
input_ids
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/chatglm.py
View file @
cdc1fa12
...
@@ -2,13 +2,13 @@
...
@@ -2,13 +2,13 @@
# Adapted from
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
"""Inference-only ChatGLM model compatible with THUDM weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
@@ -108,19 +108,11 @@ class GLMAttention(nn.Module):
...
@@ -108,19 +108,11 @@ class GLMAttention(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
context_layer
=
self
.
attn
(
context_layer
=
self
.
attn
(
q
,
k
,
v
)
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
)
attn_output
,
_
=
self
.
dense
(
context_layer
)
attn_output
,
_
=
self
.
dense
(
context_layer
)
return
attn_output
return
attn_output
...
@@ -215,8 +207,6 @@ class GLMBlock(nn.Module):
...
@@ -215,8 +207,6 @@ class GLMBlock(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# hidden_states: [num_tokens, h]
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
# Layer norm at the beginning of the transformer layer.
...
@@ -225,8 +215,6 @@ class GLMBlock(nn.Module):
...
@@ -225,8 +215,6 @@ class GLMBlock(nn.Module):
attention_output
=
self
.
self_attention
(
attention_output
=
self
.
self_attention
(
hidden_states
=
layernorm_output
,
hidden_states
=
layernorm_output
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Residual connection.
# Residual connection.
...
@@ -289,17 +277,10 @@ class GLMTransformer(nn.Module):
...
@@ -289,17 +277,10 @@ class GLMTransformer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
hidden_states
=
hidden_states
,
hidden_states
=
layer
(
position_ids
=
position_ids
)
hidden_states
=
hidden_states
,
position_ids
=
position_ids
,
kv_cache
=
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
=
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
@@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module):
...
@@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module):
...
@@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module):
hidden_states
=
self
.
encoder
(
hidden_states
=
self
.
encoder
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
position_ids
=
positions
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
)
return
hidden_states
return
hidden_states
...
@@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
...
@@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
vllm/model_executor/models/commandr.py
View file @
cdc1fa12
...
@@ -21,14 +21,14 @@
...
@@ -21,14 +21,14 @@
# This file is based on the LLama model definition file in transformers
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
"""PyTorch Cohere model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
transformers
import
CohereConfig
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
...
@@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
@@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
...
@@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
if
self
.
v1
or
self
.
sliding_window
:
if
self
.
v1
or
self
.
sliding_window
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
...
@@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
...
@@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention
=
self
.
self_attn
(
hidden_states_attention
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states_mlp
=
self
.
mlp
(
hidden_states
)
hidden_states_mlp
=
self
.
mlp
(
hidden_states
)
# Add everything together
# Add everything together
...
@@ -311,8 +305,6 @@ class CohereModel(nn.Module):
...
@@ -311,8 +305,6 @@ class CohereModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -326,13 +318,10 @@ class CohereModel(nn.Module):
...
@@ -326,13 +318,10 @@ class CohereModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/dbrx.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
...
@@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
...
@@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
if
self
.
clip_qkv
is
not
None
:
if
self
.
clip_qkv
is
not
None
:
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
hidden_states
,
_
=
self
.
out_proj
(
attn_output
)
hidden_states
,
_
=
self
.
out_proj
(
attn_output
)
return
hidden_states
return
hidden_states
...
@@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
...
@@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
norm_1
(
hidden_states
)
hidden_states
=
self
.
norm_1
(
hidden_states
)
x
=
self
.
attn
(
x
=
self
.
attn
(
position_ids
=
position_ids
,
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
x
hidden_states
=
residual
+
x
residual
=
hidden_states
residual
=
hidden_states
...
@@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
...
@@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
,
residual
=
self
.
norm_attn_norm
(
hidden_states
,
residual
=
self
.
norm_attn_norm
(
position_ids
=
position_ids
,
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
self
.
ffn
(
hidden_states
)
hidden_states
=
self
.
ffn
(
hidden_states
)
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
...
@@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
...
@@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
...
@@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
assert
intermediate_tensors
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
block
in
self
.
blocks
[
self
.
start_layer
:
self
.
end_layer
]:
block
=
self
.
blocks
[
i
]
hidden_states
=
block
(
position_ids
,
hidden_states
)
hidden_states
=
block
(
position_ids
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm_f
(
hidden_states
)
hidden_states
=
self
.
norm_f
(
hidden_states
)
...
@@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
...
@@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/deepseek.py
View file @
cdc1fa12
...
@@ -22,13 +22,13 @@
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Deepseek model."""
"""Inference-only Deepseek model."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
...
@@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module):
...
@@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module):
...
@@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
...
@@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module):
...
@@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -370,8 +364,6 @@ class DeepseekModel(nn.Module):
...
@@ -370,8 +364,6 @@ class DeepseekModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -384,11 +376,8 @@ class DeepseekModel(nn.Module):
...
@@ -384,11 +376,8 @@ class DeepseekModel(nn.Module):
else
:
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
...
@@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment